-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
174 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Model files | ||
*.npz | ||
*.caffemodel | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from __future__ import print_function | ||
import argparse | ||
|
||
import numpy as np | ||
|
||
from chainer import Variable, link, serializers | ||
|
||
from chainer.links import caffe | ||
|
||
from model import RealismCNN | ||
|
||
def cnn2fcn(src, dst): | ||
print('Copying layers %s -> %s:' % (src.__class__.__name__, dst.__class__.__name__)) | ||
|
||
for child in src.children(): | ||
dst_child = dst['fc8' if child.name.startswith('fc8') else child.name] | ||
|
||
if isinstance(child, link.Link): | ||
print('Copying {} ...'.format(child.name)) | ||
|
||
if child.name.startswith('fc'): | ||
dst_child.__dict__['W'].data[...] = np.reshape(child.__dict__['W'].data, dst_child.__dict__['W'].data.shape) | ||
dst_child.__dict__['b'].data[...] = np.reshape(child.__dict__['b'].data, dst_child.__dict__['b'].data.shape) | ||
else: | ||
dst_child.copyparams(child) | ||
|
||
print('\tlayer: %s -> %s' % (child.name, dst_child.name)) | ||
|
||
return dst | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description='Load caffe model for chainer') | ||
parser.add_argument('--caffe_model_path', default='models/realismCNN_all_iter3.caffemodel', help='Path for caffe model') | ||
parser.add_argument('--chainer_model_path', default='models/realismCNN_all_iter3.npz', help='Path for saving chainer model') | ||
args = parser.parse_args() | ||
|
||
print('Load caffe model from {} ...'.format(args.caffe_model_path)) | ||
caffe_model = caffe.CaffeFunction(args.caffe_model_path) | ||
print('Load caffe model, DONE') | ||
|
||
print('\nTurn CNN into FCN, start ...\n') | ||
chainer_model = RealismCNN() | ||
chainer_model(Variable(np.zeros((1, 3, 227, 227), dtype=np.float32), volatile='on')) | ||
chainer_model = cnn2fcn(caffe_model, chainer_model) | ||
|
||
print('\nTurn CNN into FCN, DONE. Save to {} ...'.format(args.chainer_model_path)) | ||
serializers.save_npz(args.chainer_model_path, chainer_model) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from __future__ import print_function | ||
import argparse | ||
|
||
import chainer | ||
from chainer import Variable, serializers | ||
|
||
from model import RealismCNN | ||
|
||
from utils import im_preprocess_vgg | ||
|
||
import numpy as np | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description='Predict a list of images wheather realistic or not') | ||
parser.add_argument('--gpu', type=int, default=0, help='GPU ID (negative value indicates CPU)') | ||
parser.add_argument('--model_path', default='models/realismCNN_all_iter3.npz', help='Path for pretrained model') | ||
parser.add_argument('--list_path', help='Path for file storing image list') | ||
parser.add_argument('--batch_size', type=int, default=10, help='Batchsize of 1 iteration') | ||
parser.add_argument('--load_size', type=int, default=256, help='Scale image to load_size') | ||
parser.add_argument('--result_path', default='result.txt', help='Path for file storing results') | ||
args = parser.parse_args() | ||
|
||
model = RealismCNN() | ||
print('Load pretrained model from {} ...'.format(args.model_path)) | ||
serializers.load_npz(args.model_path, model) | ||
if args.gpu >= 0: | ||
chainer.cuda.get_device(args.gpu).use() # Make a specified GPU current | ||
model.to_gpu() # Copy the model to the GPU | ||
|
||
print('Load images from {} ...'.format(args.list_path)) | ||
dataset = chainer.datasets.ImageDataset(paths=args.list_path, root='') | ||
print('{} images in total loaded'.format(len(dataset))) | ||
data_iterator = chainer.iterators.SerialIterator(dataset, args.batch_size, repeat=False, shuffle=False) | ||
|
||
scores = np.zeros((0, 2)) | ||
for idx, batch in enumerate(data_iterator): | ||
print('Processing batch {}->{}/{} ...'.format(idx*args.batch_size+1, min(len(dataset), (idx+1)*args.batch_size), len(dataset))) | ||
batch = [im_preprocess_vgg(np.transpose(im, [1, 2, 0]), args.load_size) for im in batch] | ||
batch = Variable(chainer.dataset.concat_examples(batch, args.gpu), volatile='on') | ||
result = chainer.cuda.to_cpu(model(batch, dropout=False).data) | ||
scores = np.vstack((scores, np.mean(result, axis=(2, 3)))) | ||
|
||
print('Processing DONE !') | ||
print('Saving result to {} ...'.format(args.result_path)) | ||
with open(args.result_path, 'w') as f: | ||
for score in scores: | ||
f.write('{},{}\n'.format(score[0], score[1])) | ||
|
||
if __name__ == '__main__': | ||
main() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters