Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
466 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
These scripts use a pretrained VGG network to extract features from the Office and Office-caltech datasets. | ||
|
||
The extracted features are also available from http://twanvl.nl/research/domain-adaptation-2017 | ||
|
||
|
||
Needed downloads | ||
----- | ||
|
||
Download VGG network: | ||
|
||
wget https://s3.amazonaws.com/lasagne/recipes/pretrained/imagenet/vgg_cnn_s.pkl | ||
|
||
Download office images: https://cs.stanford.edu/~jhoffman/domainadapt/ | ||
|
||
Download caltech images: http://www.vision.caltech.edu/Image_Datasets/Caltech101/#Download | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# The datasets | ||
import matplotlib.pyplot as plt | ||
import urllib.request | ||
import io | ||
import glob | ||
import numpy as np | ||
import scipy.io | ||
import scipy.ndimage | ||
import re | ||
from network import resize | ||
|
||
def load_image(filename): | ||
return scipy.ndimage.imread(filename, mode='RGB') | ||
|
||
def load_image_url(url): | ||
ext = url.split('.')[-1] | ||
return plt.imread(io.BytesIO(urllib.request.urlopen(url).read()), ext) | ||
|
||
def load_office31_domain(domain): | ||
dirs = sorted(glob.glob('~/data/office31/{}/images/*'.format(domain))) | ||
x = [] | ||
y = [] | ||
for i,dir in enumerate(dirs): | ||
for file in sorted(glob.glob(dir+'/*.jpg')): | ||
x.append(load_image(file)) | ||
y.append(i) | ||
return x,y | ||
|
||
def load_and_preprocess_office31(network): | ||
domains = ['amazon','dslr','webcam'] | ||
x = {} | ||
y = {} | ||
for domain in domains: | ||
xd, yd = load_office31_domain(domain) | ||
yd = np.array(yd, dtype=np.int32) | ||
xd = network.preprocess_many(xd) | ||
x[domain] = xd | ||
y[domain] = yd | ||
return domains, x, y | ||
|
||
def load_and_resize_office31(): | ||
domains = ['amazon','dslr','webcam'] | ||
x = {} | ||
y = {} | ||
for domain in domains: | ||
xd, yd = load_office31_domain(domain) | ||
yd = np.array(yd, dtype=np.int32) | ||
xd = np.array([resize(x) for x in xd], dtype=np.float32) | ||
x[domain] = xd | ||
y[domain] = yd | ||
return domains, x, y | ||
|
||
def load_office_caltech_domain(domain): | ||
# Load matlab files | ||
mat_suffix = 'Caltech10' if domain == 'Caltech' else domain | ||
# labels | ||
surf_file = '../data/office10/{}_SURF_L10.mat'.format(mat_suffix) | ||
y = scipy.io.loadmat(surf_file)['labels'] # 1..10 | ||
y = y[:,0] - 1 | ||
# caltech uses different category names | ||
caltech_cat_names = {'003':'backpack', '041':'coffee-mug', '045':'computer-keyboard', | ||
'046':'computer-monitor', '047':'computer-mouse', '101':'head-phones', | ||
'127':'laptop-101', '224':'touring-bike', '238':'video-projector'} | ||
# images | ||
index_file = '../data/office10/{}_SURF_L10_imgs.mat'.format(mat_suffix) | ||
img_names = scipy.io.loadmat(index_file)['imgNames'][:,0] | ||
x = [] | ||
for img_name in img_names: | ||
img_name = img_name[0] | ||
# map names: | ||
if domain == 'Caltech': | ||
# example: Caltech256_projector_238_0089 | ||
# --> data/caltech256/256_ObjectCategories/238.video-projector/238_0089.jpg | ||
cat_name, cat_id, img_id = re.match(r'Caltech256_(.*)_([^_]*)_([^_]*)$', img_name).groups() | ||
if cat_id in caltech_cat_names: | ||
cat_name = caltech_cat_names[cat_id] | ||
file = '~/data/caltech256/256_ObjectCategories/{}.{}/{}_{}.jpg'.format(cat_id, cat_name, cat_id, img_id) | ||
else: | ||
# example: amazon_projector_frame_0076 --> data/office31/amazon/projector/frame_0076.jpg | ||
dom_name, cat_name, img_id = re.match(r'([^_]*)_(.*)_(frame_[^_]*)', img_name).groups() | ||
file = '~/data/office31/{}/images/{}/{}.jpg'.format(domain, cat_name, img_id) | ||
x.append(load_image(file)) | ||
return x,y | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Use pretrained network for feature extraction | ||
from network import * | ||
from dataset import * | ||
import scipy.io | ||
|
||
domains = ['amazon','dslr','webcam'] | ||
layers = ['pool5','fc6','fc7','fc8'] | ||
|
||
print('Initializing network') | ||
network = PretrainedNetwork() | ||
|
||
for domain in domains: | ||
# load data | ||
print('Loading {}'.format(domain)) | ||
x,y = load_office31_domain(domain) | ||
y = np.array(y) | ||
# preprocess images | ||
print('Preprocessing {}'.format(domain)) | ||
x = network.preprocess_many(x) | ||
# for each layer | ||
for layer in layers: | ||
print('Generating {} {}'.format(domain, layer)) | ||
out = network.get_features(x, layer, preprocess=False) | ||
# Write to matlab matrix file | ||
scipy.io.savemat('data/office-vgg-{}-{}.mat'.format(domain,layer),{'x':out,'y':y+1}) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Use pretrained network for feature extraction, | ||
# for the office-caltech dataset (see https://cs.stanford.edu/~jhoffman/domainadapt/) | ||
# get image names from matlab files. | ||
from network import * | ||
from dataset import * | ||
import scipy.io | ||
|
||
domains = ['amazon','Caltech','dslr','webcam'] | ||
layers = ['fc6','fc7'] | ||
outdir = '~/data/office-caltech' | ||
|
||
print('Initializing network') | ||
network = PretrainedNetwork() | ||
|
||
for domain in domains: | ||
# load data | ||
print('Loading {}'.format(domain)) | ||
x,y = load_office_caltech_domain(domain) | ||
y = np.array(y) | ||
# preprocess images | ||
print('Preprocessing {}'.format(domain)) | ||
print('NB: shape is {}'.format(np.shape(x[0]))) | ||
xp = network.preprocess_many(x) | ||
# for each layer | ||
for layer in layers: | ||
print('Generating {} {}'.format(domain, layer)) | ||
out = network.get_features(xp, layer, preprocess=False) | ||
# Write to matlab matrix file | ||
scipy.io.savemat('{}/office-caltech-vgg-{}-{}.mat'.format(outdir,domain,layer),{'x':out,'y':y+1}) | ||
|
||
# With pooling | ||
# preprocess images | ||
xs = [] | ||
for flip in [False,True]: | ||
for crop_h,crop_w in [(0,0),(0,1),(1,0),(1,1),(0.5,0.5)]: | ||
print('Preprocessing {} ({}) '.format(domain,len(xs)), end='\r') | ||
px = network.preprocess_many(x,crop_h=crop_h,crop_w=crop_w,flip=flip,size=224*3//2) | ||
xs.append(px) | ||
# for each layer | ||
for layer in layers: | ||
print('Generating {} {}'.format(domain, layer)) | ||
out = [network.get_features(x, layer, preprocess=False) for x in xs] | ||
sums = sum(out) | ||
maxes = np.amax(out, 0) | ||
# Write to matlab matrix file | ||
scipy.io.savemat('{}/office-caltech-vgg-sumpool-{}-{}.mat'.format(outdir,domain,layer),{'x':sums,'y':y+1}) | ||
scipy.io.savemat('{}/office-caltech-vgg-maxpool-{}-{}.mat'.format(outdir,domain,layer),{'x':maxes,'y':y+1}) | ||
scipy.io.savemat('{}/office-caltech-vgg-catpool-{}-{}.mat'.format(outdir,domain,layer),{'x':np.array(out),'y':y+1}) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Use pretrained network for feature extraction | ||
from network import * | ||
from dataset import * | ||
import scipy.io | ||
import numpy as np | ||
|
||
domains = ['amazon','dslr','webcam'] | ||
layers = ['fc6','fc7','fc8'] | ||
|
||
print('Initializing network') | ||
network = PretrainedNetwork() | ||
outdir = '~/data/office31/vgg' | ||
|
||
for domain in domains: | ||
# load data | ||
print('Loading {}'.format(domain)) | ||
x,y = load_office31_domain(domain) | ||
y = np.array(y) | ||
# preprocess images | ||
xs = [] | ||
for flip in [False,True]: | ||
for crop_h,crop_w in [(0,0),(0,1),(1,0),(1,1),(0.5,0.5)]: | ||
print('Preprocessing {} ({}) '.format(domain,len(xs)), end='\r') | ||
px = network.preprocess_many(x,crop_h=crop_h,crop_w=crop_w,flip=flip,size=224*3//2) | ||
xs.append(px) | ||
# for each layer | ||
for layer in layers: | ||
print('Generating {} {}'.format(domain, layer)) | ||
out = [network.get_features(x, layer, preprocess=False) for x in xs] | ||
sums = sum(out) | ||
maxes = np.amax(out, 0) | ||
# Write to matlab matrix file | ||
scipy.io.savemat('{}/office-vgg-sumpool-{}-{}.mat'.format(outdir,domain,layer),{'x':sums,'y':y+1}) | ||
scipy.io.savemat('{}/office-vgg-maxpool-{}-{}.mat'.format(outdir,domain,layer),{'x':maxes,'y':y+1}) | ||
scipy.io.savemat('{}/office-vgg-catpool-{}-{}.mat'.format(outdir,domain,layer),{'x':np.array(out),'y':y+1}) | ||
|
Oops, something went wrong.