Skip to content

Commit

Permalink
return and not print rating_metrics; option to choose subset of modul…
Browse files Browse the repository at this point in the history
…es for irnv2 pooled
  • Loading branch information
subpic committed Mar 23, 2020
1 parent 7c87548 commit be090b7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
18 changes: 10 additions & 8 deletions applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def model_inception_pooled(input_shape=(None, None, 3), indexes=list(range(11)),
Similar to `model_inception_multigap`.
* input_shape: shape of the input images
* indexes: indices to use from the usual GAPs
* indexes: indices to use from the usual pools
* pool_size: spatial extend of the MLSP features
* name: name of the model
* return_sizes: return the sizes of each layer: (model, pool_sizes)
Expand Down Expand Up @@ -266,12 +266,13 @@ def model_inception_pooled(input_shape=(None, None, 3), indexes=list(range(11)),
else:
return model

def model_inceptionresnet_pooled(input_shape=(None, None, 3), pool_size=(5, 5),
name='', return_sizes=False):
def model_inceptionresnet_pooled(input_shape=(None, None, 3), indexes=list(range(43)),
pool_size=(5, 5), name='', return_sizes=False):
"""
Returns the wide MLSP features, spatially pooled, from InceptionResNetV2.
* input_shape: shape of the input images
* indexes: indices of the modules to use
* pool_size: spatial extend of the MLSP features
* name: name of the model
* return_sizes: return the sizes of each layer: (model, pool_sizes)
Expand All @@ -288,11 +289,11 @@ def model_inceptionresnet_pooled(input_shape=(None, None, 3), pool_size=(5, 5),
name='feature_resizer')

feature_layers = [l for l in model_base.layers if 'mixed' in l.name]
feature_layers = [feature_layers[i] for i in indexes]
pools = [ImageResizer(l.output) for l in feature_layers]
conc_pools = Concatenate(name='conc_pools', axis=3)(pools)

model = Model(inputs = model_base.input,
outputs = conc_pools)
model = Model(inputs = model_base.input, outputs = conc_pools)
if name: model.name = name

if return_sizes:
Expand Down Expand Up @@ -400,14 +401,15 @@ def rating_metrics(y_true, y_pred, show_plot=True):
p_srcc = np.round(srcc(y_true, y_pred),3)
p_mae = np.round(np.mean(np.abs(y_true - y_pred)),3)
p_rmse = np.round(np.sqrt(np.mean((y_true - y_pred)**2)),3)
print('SRCC: {} | PLCC: {} | MAE: {} | RMSE: {}'.\
format(p_srcc, p_plcc, p_mae, p_rmse))


if show_plot:
print('SRCC: {} | PLCC: {} | MAE: {} | RMSE: {}'.\
format(p_srcc, p_plcc, p_mae, p_rmse))
plt.plot(y_true, y_pred,'.',markersize=1)
plt.xlabel('ground-truth')
plt.ylabel('predicted')
plt.show()
return (p_srcc, p_plcc, p_mae, p_rmse)

def get_train_test_sets(ids, stratify_on='MOS', test_size=(0.2, 0.2),
save_path=None, show_histograms=False,
Expand Down
2 changes: 1 addition & 1 deletion image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def check_images(image_dir, image_types =\
pattern = os.path.join(image_dir, imtype)
file_list.extend(glob.glob(pattern))
print('Found', len(file_list), 'images')

image_names_err = []
image_names_all = []
for (i, file_path) in enumerate(file_list):
Expand Down

0 comments on commit be090b7

Please sign in to comment.