-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
200 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# The datasets | ||
import matplotlib.pyplot as plt | ||
import urllib.request | ||
import io | ||
import glob | ||
import numpy as np | ||
import scipy.io | ||
import scipy.ndimage | ||
import re | ||
from os.path import expanduser | ||
|
||
def load_image(filename):
    """Read the image file *filename* and return it decoded in RGB mode.

    Decoding is delegated to ``scipy.ndimage.imread`` with ``mode='RGB'``.
    """
    # NOTE(review): scipy.ndimage.imread is not available in recent SciPy
    # releases -- confirm the SciPy version this project is pinned to.
    pixels = scipy.ndimage.imread(filename, mode='RGB')
    return pixels
|
||
def load_image_url(url):
    """Download the image at *url* and decode it with ``plt.imread``.

    The text after the last '.' in *url* is passed to ``plt.imread`` as the
    image format.
    """
    ext = url.split('.')[-1]
    # Use the response as a context manager so the connection is closed as
    # soon as the payload has been read, instead of leaking it.
    with urllib.request.urlopen(url) as response:
        payload = response.read()
    return plt.imread(io.BytesIO(payload), ext)
|
||
def load_data(dataset, domain=None):
    """Load one domain, or every domain, of *dataset*.

    Parameters:
        dataset: 'office-31' or 'office-caltech'.
        domain: name of the domain to load. When omitted (None), all domains
            of the dataset are loaded.

    Returns:
        With *domain* given: whatever the per-domain loader returns (an
        ``(x, y)`` pair of images and labels).
        Without *domain*: a dict mapping each domain name to that pair.

    Raises:
        Exception: if *dataset* is not a known dataset name.
    """
    # The file used to define load_data twice; the one-argument version was
    # shadowed by the two-argument one and also contained a stray `return`
    # that referenced an undefined `domain`. Both are merged here.
    if dataset == 'office-31':
        loader = load_office31_domain
        domains = ['amazon', 'dslr', 'webcam']
    elif dataset == 'office-caltech':
        loader = load_office_caltech_domain
        domains = ['amazon', 'Caltech', 'dslr', 'webcam']
    else:
        raise Exception("Unknown dataset")

    if domain is None:
        return {name: loader(name) for name in domains}
    return loader(domain)
|
||
def dataset_domains(dataset):
    """Return the list of domain names belonging to *dataset*.

    Raises:
        Exception: if *dataset* is neither 'office-31' nor 'office-caltech'.
    """
    if dataset == 'office-31':
        office31 = ['amazon', 'dslr', 'webcam']
        return office31
    if dataset == 'office-caltech':
        office_caltech = ['amazon', 'Caltech', 'dslr', 'webcam']
        return office_caltech
    raise Exception("Unknown dataset")
|
||
def load_office31_domain(domain):
    """Load all .jpg images of one Office-31 domain from ~/data/office31.

    Every entry under ``<domain>/images/`` is treated as one class; classes
    are numbered from 0 in sorted order of their paths.

    Returns:
        (x, y): list of loaded images and the list of their class indices.

    Raises:
        Exception: if no image was found for *domain*.
    """
    pattern = expanduser('~/data/office31/{}/images/*').format(domain)
    x, y = [], []
    for label, class_dir in enumerate(sorted(glob.glob(pattern))):
        for image_path in sorted(glob.glob(class_dir + '/*.jpg')):
            x.append(load_image(image_path))
            y.append(label)
    if not x:
        raise Exception("No images found")
    return x, y
|
||
def load_office_caltech_domain(domain):
    """Load the images and labels of one Office-Caltech domain.

    Labels and image names are read from the ``*_SURF_L10`` matlab files under
    ``../../data/office10``; the images themselves are loaded from the
    caltech256 or office31 directories under ``~/data``.

    Returns:
        (x, y): list of loaded images and an array of labels shifted to
        start at 0.
    """
    is_caltech = domain == 'Caltech'
    # The Caltech matlab files are named 'Caltech10'; other domains use their own name
    mat_suffix = 'Caltech10' if is_caltech else domain

    # labels are stored as 1..10
    labels_path = '../../data/office10/{}_SURF_L10.mat'.format(mat_suffix)
    y = scipy.io.loadmat(labels_path)['labels'][:, 0] - 1

    # caltech uses different category names
    caltech_cat_names = {
        '003': 'backpack',
        '041': 'coffee-mug',
        '045': 'computer-keyboard',
        '046': 'computer-monitor',
        '047': 'computer-mouse',
        '101': 'head-phones',
        '127': 'laptop-101',
        '224': 'touring-bike',
        '238': 'video-projector',
    }

    # images
    names_path = '../../data/office10/{}_SURF_L10_imgs.mat'.format(mat_suffix)
    x = []
    for entry in scipy.io.loadmat(names_path)['imgNames'][:, 0]:
        img_name = entry[0]
        if is_caltech:
            # example: Caltech256_projector_238_0089
            #   --> data/caltech256/256_ObjectCategories/238.video-projector/238_0089.jpg
            cat_name, cat_id, img_id = re.match(r'Caltech256_(.*)_([^_]*)_([^_]*)$', img_name).groups()
            # Prefer the caltech256 directory name when the category id has one
            cat_name = caltech_cat_names.get(cat_id, cat_name)
            path = '~/data/caltech256/256_ObjectCategories/{}.{}/{}_{}.jpg'.format(cat_id, cat_name, cat_id, img_id)
        else:
            # example: amazon_projector_frame_0076 --> data/office31/amazon/projector/frame_0076.jpg
            _, cat_name, img_id = re.match(r'([^_]*)_(.*)_(frame_[^_]*)', img_name).groups()
            path = '~/data/office31/{}/images/{}/{}.jpg'.format(domain, cat_name, img_id)
        x.append(load_image(expanduser(path)))
    return x, y
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Use pretrained network for feature extraction | ||
import scipy.io | ||
import os.path | ||
import keras | ||
from keras import backend as K | ||
from keras.models import Model, Input | ||
import tensorflow as tf | ||
from os.path import expanduser | ||
import progressbar | ||
|
||
from dataset import * | ||
|
||
#------------------------------------------------------------------------------- | ||
# Configuration and initialization | ||
#------------------------------------------------------------------------------- | ||
|
||
# Directory the extracted feature files are written to
outdir = expanduser('~/data/domain-adaptation/')

# Extra margin in pixels that resize_image adds on each side of the target size
padding = 32

# Set memory allocation in tf/keras to growth (allocate GPU memory on demand)
# and register the resulting session with Keras
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)
|
||
#------------------------------------------------------------------------------- | ||
# Feature extraction | ||
#------------------------------------------------------------------------------- | ||
|
||
def extract_features(dataset, architecture, model):
    """Compute and save features for every domain of *dataset*.

    For each domain, the images are loaded, preprocessed for *architecture*,
    passed through *model* via ``image_features`` and written together with
    the labels to ``<outdir>/<dataset>-<domain>-<architecture>.mat`` under the
    keys 'x' (features) and 'y' (labels). Domains whose output file already
    exists are skipped.

    Parameters:
        dataset: dataset name understood by ``dataset_domains``/``load_data``.
        architecture: architecture name understood by ``preprocess_for``.
        model: model object passed on to ``image_features``.
    """
    # Make sure the output directory exists before doing any work, so that
    # savemat cannot fail on a missing directory after the expensive
    # feature computation has already run.
    os.makedirs(outdir, exist_ok=True)

    for domain in dataset_domains(dataset):
        filename = '{}/{}-{}-{}.mat'.format(outdir, dataset, domain, architecture)
        if os.path.isfile(filename):
            continue

        print('Loading {} {}'.format(dataset, domain))
        x, y = load_data(dataset, domain)
        # Labels are saved as float32, shifted up by one
        y = np.asarray(y, dtype=np.float32) + 1

        print('Preprocessing')
        bar = progressbar.ProgressBar()
        x = [preprocess_for(architecture, im) for im in bar(x)]

        print('Calculating features for {} {}'.format(dataset, domain))
        bar = progressbar.ProgressBar()
        features = np.stack([image_features(im, model) for im in bar(x)])

        print('Saving {}'.format(filename))
        scipy.io.savemat(filename, {'x': features, 'y': y})
|
||
# Removes the last layer from a Keras pretrained model
def remove_last_layer(model):
    """Return a new model that outputs the penultimate layer of *model*.

    The new model shares *model*'s inputs and layers. *model* itself is left
    untouched: earlier this function popped the last layer from
    ``model.layers`` and reset ``outbound_nodes`` in place, which mutates the
    caller's model and depends on Keras internals.
    """
    penultimate_output = model.layers[-2].output
    return Model(inputs=model.inputs, outputs=[penultimate_output])
|
||
def create_model(architecture):
    """Build the ImageNet-pretrained network *architecture* without its last layer.

    Parameters:
        architecture: one of 'inception_resnet_v2', 'xception',
            'inception_v3', 'resnet50', 'vgg19', 'vgg16'.

    Returns:
        The pretrained Keras model passed through ``remove_last_layer``.

    Raises:
        ValueError: if *architecture* is not one of the names above.
    """
    # Compare names with ==, not `is`: identity of equal strings is an
    # interpreter detail, so `is` can fail for names that are not literals
    # in this module (e.g. read from the command line).
    if architecture == 'inception_resnet_v2':
        return remove_last_layer(keras.applications.inception_resnet_v2.InceptionResNetV2(weights='imagenet'))
    elif architecture == 'xception':
        return remove_last_layer(keras.applications.xception.Xception(weights='imagenet'))
    elif architecture == 'inception_v3':
        return remove_last_layer(keras.applications.inception_v3.InceptionV3(weights='imagenet'))
    elif architecture == 'resnet50':
        return remove_last_layer(keras.applications.resnet50.ResNet50(weights='imagenet'))
    elif architecture == 'vgg19':
        return remove_last_layer(keras.applications.vgg19.VGG19(weights='imagenet'))
    elif architecture == 'vgg16':
        return remove_last_layer(keras.applications.vgg16.VGG16(weights='imagenet'))
    # Fail loudly instead of silently returning None
    raise ValueError("Unknown architecture: {!r}".format(architecture))
|
||
def resize_image(x, size):
    """Resize image *x* to a square of ``size + 2*padding`` pixels per side.

    The result is returned as a float32 array. ``padding`` is the
    module-level margin constant.
    """
    side = size + 2 * padding
    resized = scipy.misc.imresize(x, (side, side, 3))
    return np.asarray(resized, dtype='float32')
|
||
def preprocess_for(architecture, x):
    """Resize image *x* and apply the input preprocessing of *architecture*.

    The image is resized with ``resize_image`` to the network's input size
    (plus the padding margin that ``resize_image`` adds) and then passed to
    the matching ``keras.applications.<architecture>.preprocess_input``.

    Parameters:
        architecture: one of 'inception_resnet_v2', 'xception',
            'inception_v3', 'resnet50', 'vgg19', 'vgg16'.
        x: image to preprocess.

    Raises:
        ValueError: if *architecture* is not one of the names above.
    """
    # Names are compared with == (string identity via `is` is not reliable).
    # The Inception-family networks take 299x299 inputs when built with their
    # ImageNet classification top; the previous value 259 produced crops the
    # pretrained models cannot accept.
    if architecture == 'inception_resnet_v2':
        return keras.applications.inception_resnet_v2.preprocess_input(resize_image(x, 299))
    elif architecture == 'xception':
        return keras.applications.xception.preprocess_input(resize_image(x, 299))
    elif architecture == 'inception_v3':
        return keras.applications.inception_v3.preprocess_input(resize_image(x, 299))
    elif architecture == 'resnet50':
        return keras.applications.resnet50.preprocess_input(resize_image(x, 224))
    elif architecture == 'vgg19':
        return keras.applications.vgg19.preprocess_input(resize_image(x, 224))
    elif architecture == 'vgg16':
        return keras.applications.vgg16.preprocess_input(resize_image(x, 224))
    # Fail loudly instead of silently returning None
    raise ValueError("Unknown architecture: {!r}".format(architecture))
|
||
# Get features for image x from model: the average of 18 image representations
# (9 image crops * 2 variations: normal and horizontally flipped)
def image_features(x, model):
    """Return the mean of ``model.predict`` over 18 crops of image *x*.

    Nine crops are taken at offsets 0, padding and 2*padding along the first
    two axes; each crop has side ``x.shape[0] - 2*padding``. Every crop is
    also added in a left-right flipped version.
    """
    # assumes x is square with layout (height, width, channels) -- TODO confirm
    crop_size = x.shape[0] - 2 * padding
    offsets = [padding * step for step in range(3)]
    crops = [
        x[top:top + crop_size, left:left + crop_size, :]
        for top in offsets
        for left in offsets
    ]
    # Build the flipped copies as a separate list: extending a list with a
    # generator over that same list would never terminate.
    mirrored = [np.fliplr(crop) for crop in crops]
    batch = np.stack(crops + mirrored)
    return np.mean(model.predict(batch), axis=0)
|
||
#------------------------------------------------------------------------------- | ||
# Main | ||
#------------------------------------------------------------------------------- | ||
|
||
if __name__ == "__main__":
    # Each pretrained network is built once and reused for both datasets
    architectures = ['vgg16', 'vgg19', 'resnet50', 'inception_v3', 'xception', 'inception_resnet_v2']
    datasets = ['office-caltech', 'office-31']
    for architecture in architectures:
        model = create_model(architecture)
        for dataset in datasets:
            extract_features(dataset, architecture, model)
|