Skip to content

Commit

Permalink
1. Fixed problems with validation of float TextBoxes.
Browse files Browse the repository at this point in the history
2. Added border to the main window.
3. Added (un)select all to categories ListBox control.
4. Added dependency property initialization helpers to the BindableListBox control.
5. Fixed problems with crashing when typing prohibited character to the hidden layers configuration TextBox.
6. Updated assembly versions (merge to master).
  • Loading branch information
Szymon Bartnik committed Jun 14, 2016
1 parent 3a1dfb4 commit 2bceb94
Show file tree
Hide file tree
Showing 14 changed files with 152 additions and 84 deletions.
6 changes: 3 additions & 3 deletions ImageClassification.Core/Engines/Classifier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Wkiro.ImageClassification.Core.Engines
{
internal class Classifier
{
private readonly IGuiLogger _logger;
private readonly IGuiLogger _guiLogger;
private readonly DeepBeliefNetwork _network;
private readonly ClassifierConfiguration _configuration;

Expand All @@ -18,7 +18,7 @@ internal class Classifier

private Classifier(IGuiLogger logger, ClassifierConfiguration configuration)
{
_logger = logger;
_guiLogger = logger;
_configuration = configuration;
}

Expand All @@ -37,7 +37,7 @@ public CategoryClassification ClassifyToCategory(double[] dataToClassify)
var categoryIndex = GetIndexOfResult(output);
var predictedCategory = categories.Single(x => x.Index == categoryIndex);

_logger.LogWriteLine($"Prediction: {predictedCategory}");
_guiLogger.LogWriteLine($"Prediction: {predictedCategory}");

var result = new CategoryClassification(
predictedCategory,
Expand Down
10 changes: 9 additions & 1 deletion ImageClassification.Core/Engines/DataProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
using System.IO;
using System.Linq;
using Accord.Imaging.Converters;
using NLog;
using Wkiro.ImageClassification.Core.Engines.ImagePreprocessing.Helpers;
using Wkiro.ImageClassification.Core.Infrastructure.Logging;
using Wkiro.ImageClassification.Core.Models.Configurations;
using Wkiro.ImageClassification.Core.Models.Dto;

Expand All @@ -25,7 +27,7 @@ public DataProvider(DataProviderConfiguration dataProviderconfiguration)
}

public DataProvider(
DataProviderConfiguration dataProviderConfiguration,
DataProviderConfiguration dataProviderConfiguration,
GlobalTrainerConfiguration globalTrainerConfiguration)
: this(dataProviderConfiguration)
{
Expand Down Expand Up @@ -106,6 +108,12 @@ private int GetTrainSampleCount(int allSamplesCount)

private LearningSet GetCategoryLearningSet(Category category, int numberOfCategories)
{
if (!category.Files.Any())
{
var errorInfo = $"No files found in '{category.Name}' category.";
throw new InvalidOperationException(errorInfo);
}

var inputOutputsData = new InputsOutputsData();
var imagePreprocessingStrategy = _dataProviderconfiguration.ToImagePreprocessingStrategy();

Expand Down
10 changes: 5 additions & 5 deletions ImageClassification.Core/Facades/ClassifierFacade.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ public class ClassifierFacade
private readonly Classifier _classifier;
private readonly IStorage _storage = new Storage();

internal ClassifierFacade(DataProvider dataProvider, Classifier classifier)
internal ClassifierFacade(
DataProvider dataProvider,
Classifier classifier)
{
_classifier = classifier;
_dataProvider = dataProvider;
}

public ClassifierFacade(
string savedModelPath,
IGuiLogger logger)
public ClassifierFacade(string savedModelPath, IGuiLogger guiLogger)
{
var model = _storage.LoadModel(savedModelPath);
_dataProvider = new DataProvider(model.DataProviderConfiguration);
_classifier = new Classifier(model.Network, model.ClassifierConfiguration, logger);
_classifier = new Classifier(model.Network, model.ClassifierConfiguration, guiLogger);
}

public CategoryClassification ClassifyToCategory(string imageToClassifyPath)
Expand Down
13 changes: 7 additions & 6 deletions ImageClassification.Core/Facades/LearningFacade.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ public class LearningFacade
private readonly DataProvider _dataProvider;
private readonly GlobalTrainerConfiguration _globalTrainerConfiguration;
private readonly SkipPhaseRequest _skipPhaseRequest;
private readonly IGuiLogger _logger;
private readonly IGuiLogger _guiLogger;

public LearningFacade(
DataProviderConfiguration dataProviderConfiguration,
GlobalTrainerConfiguration globalTrainerConfiguration,
SkipPhaseRequest skipPhaseRequest,
IGuiLogger logger)
IGuiLogger guiLogger)
{
_guiLogger = guiLogger;
_dataProvider = new DataProvider(dataProviderConfiguration, globalTrainerConfiguration);
_globalTrainerConfiguration = globalTrainerConfiguration;
_skipPhaseRequest = skipPhaseRequest;
_logger = logger;
}

public IEnumerable<Category> GetAvailableCategories()
Expand Down Expand Up @@ -55,19 +55,20 @@ private ClassifierFacade RunTrainingForSelectedCategoriesImpl(TrainingParameters
var layers = _globalTrainerConfiguration.HiddenLayers.ToList();
int outputLayerSize = categories.Length;
layers.Add(outputLayerSize);

var trainer = new Trainer(new TrainerConfiguration
{
Layers = layers.ToArray(),
InputsOutputsData = learningSet.TrainingData.ToInputOutputsDataNative(),
}, _skipPhaseRequest, _logger);
}, _skipPhaseRequest, _guiLogger);

trainer.RunTraining1(trainingParameters.Training1Parameters);
trainer.RunTraining2(trainingParameters.Training2Parameters);

trainer.CheckAccuracy(learningSet.TestData.ToInputOutputsDataNative());

var classifierConfiguration = new ClassifierConfiguration() { Categories = categories };
var classifier = new Classifier(trainer.NeuralNetwork, classifierConfiguration, _logger);
var classifierConfiguration = new ClassifierConfiguration { Categories = categories };
var classifier = new Classifier(trainer.NeuralNetwork, classifierConfiguration, _guiLogger);

var classifierFacade = new ClassifierFacade(_dataProvider, classifier);
return classifierFacade;
Expand Down
2 changes: 1 addition & 1 deletion ImageClassification.Core/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
// The following GUID is for the ID of the typelib if this project is exposed to COM
[assembly: Guid("81b753f3-b336-4763-ad79-ebf62859750f")]

[assembly: AssemblyVersion("1.1.*")]
[assembly: AssemblyVersion("2.0.*")]
1 change: 1 addition & 0 deletions ImageClassification.Gui/App.xaml.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ public partial class App : Application
{
public App()
{
FrameworkCompatibilityPreferences.KeepTextBoxDisplaySynchronizedWithTextProperty = false;
Startup += OnAppStartup;
}

Expand Down
63 changes: 55 additions & 8 deletions ImageClassification.Gui/Helpers/BindableListBox.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
using System.Collections;
using System;
using System.Collections;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Windows;
using System.Windows.Controls;
using Wkiro.ImageClassification.Core.Models.Dto;

namespace Wkiro.ImageClassification.Gui.Helpers
{
Expand All @@ -13,12 +18,12 @@ public BindableListBox()

private void ListBoxCustom_SelectionChanged(object sender, SelectionChangedEventArgs e)
{
SelectedItemsList = SelectedItems;
SelectedItemsList = new ObservableCollection<Category>(SelectedItems.Cast<Category>());
}

public IList SelectedItemsList
public ObservableCollection<Category> SelectedItemsList
{
get { return (IList)GetValue(SelectedItemsListProperty); }
get { return (ObservableCollection<Category>)GetValue(SelectedItemsListProperty); }
set { SetValue(SelectedItemsListProperty, value); }
}

Expand All @@ -35,9 +40,51 @@ protected override void OnItemsSourceChanged(IEnumerable oldValue, IEnumerable n
SelectAll();
}

public static readonly DependencyProperty SelectedItemsListProperty = DependencyProperty
.Register("SelectedItemsList", typeof(IList), typeof(BindableListBox), new PropertyMetadata(null));
public static readonly DependencyProperty SelectAllOnSourceChangeProperty = DependencyProperty
.Register("SelectAllOnSourceChange", typeof(bool), typeof(BindableListBox), new PropertyMetadata(defaultValue: false));
public static readonly DependencyProperty SelectedItemsListProperty = RegisterProperty(x => x.SelectedItemsList, null, x => x.OnSelectedItemsListChanged());

private void OnSelectedItemsListChanged()
{
SetSelectedItems(SelectedItemsList);
}

public static readonly DependencyProperty SelectAllOnSourceChangeProperty = RegisterProperty(x => x.SelectAllOnSourceChange, false);

#region Dependency Property initialization area

private static string GetPropertyName<TObject1, T>(Expression<Func<TObject1, T>> exp)
{
var body = exp.Body;
var convertExpression = body as UnaryExpression;
if (convertExpression == null)
return ((MemberExpression) body).Member.Name;

if (convertExpression.NodeType != ExpressionType.Convert)
{
throw new ArgumentException("Invalid property expression.", nameof(exp));
}
body = convertExpression.Operand;
return ((MemberExpression)body).Member.Name;
}

private static DependencyProperty RegisterProperty<T>(
Expression<Func<BindableListBox, T>> associatedProperty,
T defaultValue,
Action<BindableListBox> valueChangedAction = null)
{
return DependencyProperty.Register(
GetPropertyName(associatedProperty),
typeof(T),
typeof(BindableListBox),
new PropertyMetadata(defaultValue, (s, e) =>
{
var sender = s as BindableListBox;
if (sender != null)
{
valueChangedAction?.Invoke(sender);
}
}));
}

#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@ public object ConvertBack(object value, Type targetType, object parameter, Cultu
.Where(x => !string.IsNullOrWhiteSpace(x))
.Select(int.Parse);

return list.ToArray();
try
{
return list.ToArray();
}
catch(Exception ex)
{
return null;
}
}
}
}
1 change: 0 additions & 1 deletion ImageClassification.Gui/ImageClassification.Gui.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@
<DesignTimeSharedInput>True</DesignTimeSharedInput>
<DependentUpon>Settings.settings</DependentUpon>
</Compile>
<Compile Include="Settings.cs" />
<Compile Include="Startup.cs" />
<Compile Include="Configuration\HardcodedConfigurationManager.cs" />
<Compile Include="ViewModels\MainWindowViewModel.Commands.cs" />
Expand Down
2 changes: 1 addition & 1 deletion ImageClassification.Gui/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
// The following GUID is for the ID of the typelib if this project is exposed to COM
[assembly: Guid("1bed70ba-3ab8-46d0-a2f0-72339c362530")]

[assembly: AssemblyVersion("1.2.*")]
[assembly: AssemblyVersion("2.0.*")]
28 changes: 0 additions & 28 deletions ImageClassification.Gui/Settings.cs

This file was deleted.

35 changes: 26 additions & 9 deletions ImageClassification.Gui/ViewModels/MainWindowViewModel.Commands.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,39 @@ public partial class MainWindowViewModel
public RelayCommand SaveNetworkCommand { get; set; }
public RelayCommand ReconfigureCommand { get; set; }
public RelayCommand CancelComputingCommand { get; set; }
public RelayCommand CategoriesSelectAllCommand { get; set; }
public RelayCommand CategoriesUnselectAllCommand { get; set; }

private void InitializeCommands()
{
BrowseForTrainFilesPathCommand = new RelayCommand(BrowseForTrainFilesPath);
ConfigureNewTrainingCommand = new RelayCommand(ConfigureNewTraining);
LoadSavedNetworkCommand = new RelayCommand(LoadSavedNetwork);
SelectedCategoriesChangedCommand = new RelayCommand<object>(SelectedCategoriesChanged);

StartTrainingCommand = new RelayCommand(StartTraining);
ClassifyImageCommand = new RelayCommand(ClassifyImage);
CancelComputingCommand = new RelayCommand(CancelComputing);

SaveNetworkCommand = new RelayCommand(SaveNetwork);
ReconfigureCommand = new RelayCommand(Reconfigure);

CategoriesSelectAllCommand = new RelayCommand(() => CategoriesSelect(Select.All));
CategoriesUnselectAllCommand = new RelayCommand(() => CategoriesSelect(Select.None));
}

private void CategoriesSelect(Select selectEnum)
{
switch (selectEnum)
{
case Select.All:
SelectedCategories = new ObservableCollection<Category>(AvailableCategories);
break;
case Select.None:
SelectedCategories = new ObservableCollection<Category>();
break;
default:
throw new ArgumentOutOfRangeException(nameof(selectEnum), selectEnum, null);
}
}

private void Reconfigure()
Expand All @@ -46,9 +65,7 @@ private void BrowseForTrainFilesPath()
var directory = _dataProviderConfiguration.TrainFilesLocationPath;
var dialog = new FolderBrowserDialog
{
RootFolder = Environment.SpecialFolder.MyComputer,
Description = "Select directory containing training folders",
ShowNewFolderButton = false
RootFolder = Environment.SpecialFolder.MyComputer, Description = "Select directory containing training folders", ShowNewFolderButton = false
};

if (Directory.Exists(directory))
Expand All @@ -57,11 +74,11 @@ private void BrowseForTrainFilesPath()
if (dialog.ShowDialog() == DialogResult.OK)
_dataProviderConfiguration.TrainFilesLocationPath = dialog.SelectedPath;
}
}

private void SelectedCategoriesChanged(object categories)
{
var casted = ((IList) categories).Cast<Category>();
SelectedCategories = new ObservableCollection<Category>(casted);
}
internal enum Select
{
All,
None,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private void LoadSavedNetwork()
{
_classifierFacade = new ClassifierFacade(
savedModelPath: fileName,
logger: this);
guiLogger: this);
var loadedModel = _classifierFacade.GetCurrentModel();
DataProviderConfiguration = loadedModel.DataProviderConfiguration;
AvailableCategories = new ObservableCollection<Category>(loadedModel.ClassifierConfiguration.Categories);
Expand Down
Loading

0 comments on commit 2bceb94

Please sign in to comment.