diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs index c93403b59e..5f84712625 100644 --- a/src/Microsoft.ML.Api/SchemaDefinition.cs +++ b/src/Microsoft.ML.Api/SchemaDefinition.cs @@ -332,8 +332,9 @@ public static SchemaDefinition Create(Type userType) if (fieldInfo.GetCustomAttribute() != null) continue; - var mappingAttr = fieldInfo.GetCustomAttribute(); - var name = mappingAttr == null ? fieldInfo.Name : (mappingAttr.Name ?? fieldInfo.Name); + var mappingAttr = fieldInfo.GetCustomAttribute(); + var mappingNameAttr = fieldInfo.GetCustomAttribute(); + string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? fieldInfo.Name; // Disallow duplicate names, because the field enumeration order is not actually // well defined, so we are not gauranteed to have consistent "hiding" from run to // run, across different .NET versions. diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 46002c5abf..ecea73a495 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -12300,13 +12300,13 @@ namespace Runtime { public abstract class CalibratorTrainer : ComponentKind {} + + /// /// /// public sealed class FixedPlattCalibratorCalibratorTrainer : CalibratorTrainer { - - /// /// The slope parameter of f(x) = 1 / (1 + exp(-slope * x + offset) /// @@ -12320,45 +12320,45 @@ public sealed class FixedPlattCalibratorCalibratorTrainer : CalibratorTrainer internal override string ComponentName => "FixedPlattCalibrator"; } + + /// /// /// public sealed class NaiveCalibratorCalibratorTrainer : CalibratorTrainer { - - internal override string ComponentName => "NaiveCalibrator"; } + + /// /// /// public sealed class PavCalibratorCalibratorTrainer : CalibratorTrainer { - - internal override string ComponentName => "PavCalibrator"; } + + /// /// Platt calibration. /// public sealed class PlattCalibratorCalibratorTrainer : CalibratorTrainer { - - internal override string ComponentName => "PlattCalibrator"; } public abstract class ClassificationLossFunction : ComponentKind {} + + /// /// Exponential loss. /// public sealed class ExpLossClassificationLossFunction : ClassificationLossFunction { - - /// /// Beta (dilation) /// @@ -12367,13 +12367,13 @@ public sealed class ExpLossClassificationLossFunction : ClassificationLossFuncti internal override string ComponentName => "ExpLoss"; } + + /// /// Hinge loss. /// public sealed class HingeLossClassificationLossFunction : ClassificationLossFunction { - - /// /// Margin value /// @@ -12382,23 +12382,23 @@ public sealed class HingeLossClassificationLossFunction : ClassificationLossFunc internal override string ComponentName => "HingeLoss"; } + + /// /// Log loss. /// public sealed class LogLossClassificationLossFunction : ClassificationLossFunction { - - internal override string ComponentName => "LogLoss"; } + + /// /// Smoothed Hinge loss. /// public sealed class SmoothedHingeLossClassificationLossFunction : ClassificationLossFunction { - - /// /// Smoothing constant /// @@ -12409,13 +12409,13 @@ public sealed class SmoothedHingeLossClassificationLossFunction : Classification public abstract class EarlyStoppingCriterion : ComponentKind {} + + /// /// Stop in case of loss of generality. /// public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion { - - /// /// Threshold in range [0,1]. /// @@ -12425,13 +12425,13 @@ public sealed class GLEarlyStoppingCriterion : EarlyStoppingCriterion internal override string ComponentName => "GL"; } + + /// /// Stops in case of low progress. /// public sealed class LPEarlyStoppingCriterion : EarlyStoppingCriterion { - - /// /// Threshold in range [0,1]. /// @@ -12447,13 +12447,13 @@ public sealed class LPEarlyStoppingCriterion : EarlyStoppingCriterion internal override string ComponentName => "LP"; } + + /// /// Stops in case of generality to progress ration exceeds threshold. /// public sealed class PQEarlyStoppingCriterion : EarlyStoppingCriterion { - - /// /// Threshold in range [0,1]. /// @@ -12469,13 +12469,13 @@ public sealed class PQEarlyStoppingCriterion : EarlyStoppingCriterion internal override string ComponentName => "PQ"; } + + /// /// Stop if validation score exceeds threshold value. /// public sealed class TREarlyStoppingCriterion : EarlyStoppingCriterion { - - /// /// Tolerance threshold. (Non negative value) /// @@ -12485,13 +12485,13 @@ public sealed class TREarlyStoppingCriterion : EarlyStoppingCriterion internal override string ComponentName => "TR"; } + + /// /// Stops in case of consecutive loss in generality. /// public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion { - - /// /// The window size. /// @@ -12503,13 +12503,13 @@ public sealed class UPEarlyStoppingCriterion : EarlyStoppingCriterion public abstract class FastTreeTrainer : ComponentKind {} + + /// /// Uses a logit-boost boosted tree learner to perform binary classification. /// public sealed class FastTreeBinaryClassificationFastTreeTrainer : FastTreeTrainer { - - /// /// Should we use derivatives optimized for unbalanced sets /// @@ -12856,13 +12856,13 @@ public sealed class FastTreeBinaryClassificationFastTreeTrainer : FastTreeTraine internal override string ComponentName => "FastTreeBinaryClassification"; } + + /// /// Trains gradient boosted decision trees to the LambdaRank quasi-gradient. /// public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer { - - /// /// Comma seperated list of gains associated to each relevance label. /// @@ -13244,13 +13244,13 @@ public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer internal override string ComponentName => "FastTreeRanking"; } + + /// /// Trains gradient boosted decision trees to fit target values using least-squares. /// public sealed class FastTreeRegressionFastTreeTrainer : FastTreeTrainer { - - /// /// Use best regression step trees? /// @@ -13592,13 +13592,13 @@ public sealed class FastTreeRegressionFastTreeTrainer : FastTreeTrainer internal override string ComponentName => "FastTreeRegression"; } + + /// /// Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression. /// public sealed class FastTreeTweedieRegressionFastTreeTrainer : FastTreeTrainer { - - /// /// Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, and intermediate values are compound Poisson loss. /// @@ -13947,13 +13947,13 @@ public sealed class FastTreeTweedieRegressionFastTreeTrainer : FastTreeTrainer public abstract class NgramExtractor : ComponentKind {} + + /// /// Extracts NGrams from text and convert them to vector using dictionary. /// public sealed class NGramNgramExtractor : NgramExtractor { - - /// /// Ngram length /// @@ -13982,13 +13982,13 @@ public sealed class NGramNgramExtractor : NgramExtractor internal override string ComponentName => "NGram"; } + + /// /// Extracts NGrams from text and convert them to vector using hashing trick. /// public sealed class NGramHashNgramExtractor : NgramExtractor { - - /// /// Ngram length /// @@ -14029,45 +14029,45 @@ public sealed class NGramHashNgramExtractor : NgramExtractor public abstract class ParallelTraining : ComponentKind {} + + /// /// Single node machine learning process. /// public sealed class SingleParallelTraining : ParallelTraining { - - internal override string ComponentName => "Single"; } public abstract class RegressionLossFunction : ComponentKind {} + + /// /// Poisson loss. /// public sealed class PoissonLossRegressionLossFunction : RegressionLossFunction { - - internal override string ComponentName => "PoissonLoss"; } + + /// /// Squared loss. /// public sealed class SquaredLossRegressionLossFunction : RegressionLossFunction { - - internal override string ComponentName => "SquaredLoss"; } + + /// /// Tweedie loss. /// public sealed class TweedieLossRegressionLossFunction : RegressionLossFunction { - - /// /// Index parameter for the Tweedie distribution, in the range [1, 2]. 1 is Poisson loss, 2 is gamma loss, and intermediate values are compound Poisson loss. /// @@ -14078,13 +14078,13 @@ public sealed class TweedieLossRegressionLossFunction : RegressionLossFunction public abstract class SDCAClassificationLossFunction : ComponentKind {} + + /// /// Hinge loss. /// public sealed class HingeLossSDCAClassificationLossFunction : SDCAClassificationLossFunction { - - /// /// Margin value /// @@ -14093,23 +14093,23 @@ public sealed class HingeLossSDCAClassificationLossFunction : SDCAClassification internal override string ComponentName => "HingeLoss"; } + + /// /// Log loss. /// public sealed class LogLossSDCAClassificationLossFunction : SDCAClassificationLossFunction { - - internal override string ComponentName => "LogLoss"; } + + /// /// Smoothed Hinge loss. /// public sealed class SmoothedHingeLossSDCAClassificationLossFunction : SDCAClassificationLossFunction { - - /// /// Smoothing constant /// @@ -14120,25 +14120,25 @@ public sealed class SmoothedHingeLossSDCAClassificationLossFunction : SDCAClassi public abstract class SDCARegressionLossFunction : ComponentKind {} + + /// /// Squared loss. /// public sealed class SquaredLossSDCARegressionLossFunction : SDCARegressionLossFunction { - - internal override string ComponentName => "SquaredLoss"; } public abstract class StopWordsRemover : ComponentKind {} + + /// /// Remover with list of stopwords specified by the user. /// public sealed class CustomStopWordsRemover : StopWordsRemover { - - /// /// List of stopwords /// @@ -14147,13 +14147,13 @@ public sealed class CustomStopWordsRemover : StopWordsRemover internal override string ComponentName => "Custom"; } + + /// /// Remover with predefined list of stop words. /// public sealed class PredefinedStopWordsRemover : StopWordsRemover { - - internal override string ComponentName => "Predefined"; } diff --git a/src/Microsoft.ML/Runtime/Experiment/Experiment.cs b/src/Microsoft.ML/Runtime/Experiment/Experiment.cs index aa5736bd14..9fb0560701 100644 --- a/src/Microsoft.ML/Runtime/Experiment/Experiment.cs +++ b/src/Microsoft.ML/Runtime/Experiment/Experiment.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.IO; -using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.EntryPoints.JsonUtils; using Newtonsoft.Json; @@ -165,7 +164,7 @@ private string Serialize(string name, object input, object output) { using (var jw = new JsonTextWriter(sw)) { - jw.Formatting = Formatting.Indented; + jw.Formatting = Newtonsoft.Json.Formatting.Indented; _serializer.Serialize(jw, _helper); } return sw.ToString(); diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index 17643518ca..f1e45fa446 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -74,7 +74,7 @@ public static string GetOutputType(Type outputType) return $"Var<{GetCSharpTypeName(outputType)}>"; } - public static string GetInputType(ModuleCatalog catalog, Type inputType, + public static string GetInputType(ModuleCatalog catalog, Type inputType, Dictionary typesSymbolTable, string rootNameSpace = "") { if (inputType.IsGenericType && inputType.GetGenericTypeDefinition() == typeof(Var<>)) @@ -136,13 +136,13 @@ public static string GetOutputType(Type outputType) return $"{enumName}"; default: if (isNullable) - return rootNameSpace+typesSymbolTable[type.FullName]; + return rootNameSpace + typesSymbolTable[type.FullName]; if (isOptional) - return $"Optional<{rootNameSpace+typesSymbolTable[type.FullName]}>"; + return $"Optional<{rootNameSpace + typesSymbolTable[type.FullName]}>"; if (typesSymbolTable.ContainsKey(type.FullName)) return rootNameSpace + typesSymbolTable[type.FullName]; else - return GetSymbolFromType(typesSymbolTable, type.FullName, rootNameSpace); + return GetSymbolFromType(typesSymbolTable, type, rootNameSpace); } } @@ -177,7 +177,7 @@ public static string Capitalize(string s) return char.ToUpperInvariant(s[0]) + s.Substring(1); } - public static string GetValue(ModuleCatalog catalog, Type fieldType, object fieldValue, + public static string GetValue(ModuleCatalog catalog, Type fieldType, object fieldValue, Dictionary typesSymbolTable, string rootNameSpace = "") { if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Var<>)) @@ -299,7 +299,7 @@ public static string Capitalize(string s) var properties = propertyBag.Count > 0 ? $" {{ {string.Join(", ", propertyBag)} }}" : ""; return $"new {GetComponentName(componentInfo)}(){properties}"; case TlcModule.DataKind.Unknown: - return $"new {rootNameSpace+typesSymbolTable[fieldType.FullName]}()"; + return $"new {rootNameSpace + typesSymbolTable[fieldType.FullName]}()"; default: return fieldValue.ToString(); } @@ -321,7 +321,7 @@ public static string GetEnumName(Type type, Dictionary typesSymb if (typesSymbolTable.ContainsKey(type.FullName)) return rootNamespace + typesSymbolTable[type.FullName]; else - return GetSymbolFromType(typesSymbolTable, type.FullName, rootNamespace); + return GetSymbolFromType(typesSymbolTable, type, rootNamespace); } public static string GetJsonFromField(string fieldName, Type fieldType) @@ -495,16 +495,72 @@ private void GenerateFooter(IndentingTextWriter writer) writer.WriteLine(); } - static string GetSymbolFromType(Dictionary typesSymbolTable, string fullTypeName, string currentNamespace) + /// + /// This methods creates a unique name for a class/struct/enum, given a type and a namespace. + /// It generates the name based on the property of the type + /// (see description here https://msdn.microsoft.com/en-us/library/system.type.fullname(v=vs.110).aspx). + /// Example: Assume we have the following structure in namespace X.Y: + /// class A { + /// class B { + /// enum C { + /// Value1, + /// Value2 + /// } + /// } + /// } + /// The full name of C would be X.Y.A+B+C. This method will generate the name "ABC" from it. In case + /// A is generic with one generic type, then the full name of typeof(A<float>.B.C) would be X.Y.A`1+B+C[[System.Single]]. + /// In this case, this method will generate the name "ASingleBC". + /// + /// A dictionary containing the names of the classes already generated. + /// This parameter is only used to ensure that the newly generated name is unique. + /// The type for which to generate the new name. + /// The namespace prefix to the new name. + /// A unique name derived from the given type and namespace. + private static string GetSymbolFromType(Dictionary typesSymbolTable, Type type, string currentNamespace) { - var names = typesSymbolTable.Select(kvp => kvp.Value); - char dim = fullTypeName.Contains('+') ? '+' : '.'; + var fullTypeName = type.FullName; string name = currentNamespace != "" ? currentNamespace + '.' : ""; - if (fullTypeName.Contains('+')) - name += fullTypeName.Substring(0, fullTypeName.LastIndexOf('+')).Substring(fullTypeName.LastIndexOf('.') + 1); + int bracketIndex = fullTypeName.IndexOf('['); + Type[] genericTypes = null; + if (type.IsGenericType) + genericTypes = type.GetGenericArguments(); + if (bracketIndex > 0) + { + Contracts.AssertValue(genericTypes); + fullTypeName = fullTypeName.Substring(0, bracketIndex); + } + + // When the type is nested, the names of the outer types are concatenated with a '+'. + var nestedNames = fullTypeName.Split('+'); + var baseName = nestedNames[0]; + + // We currently only handle generic types in the outer most class, support for generic inner classes + // can be added if needed. + int backTickIndex = baseName.LastIndexOf('`'); + int dotIndex = baseName.LastIndexOf('.'); + Contracts.Assert(dotIndex >= 0); + if (backTickIndex < 0) + name += baseName.Substring(dotIndex + 1); + else + { + name += baseName.Substring(dotIndex + 1, backTickIndex - dotIndex - 1); + Contracts.AssertValue(genericTypes); + if (genericTypes != null) + { + foreach (var genType in genericTypes) + { + var splitNames = genType.FullName.Split('+'); + if (splitNames[0].LastIndexOf('.') >= 0) + splitNames[0] = splitNames[0].Substring(splitNames[0].LastIndexOf('.') + 1); + name += string.Join("", splitNames); + } + } + } - name += fullTypeName.Substring(fullTypeName.LastIndexOf(dim) + 1); ; + for (int i = 1; i < nestedNames.Length; i++) + name += nestedNames[i]; Contracts.Assert(typesSymbolTable.Select(kvp => kvp.Value).All(str => string.Compare(str, name) != 0)); @@ -538,7 +594,7 @@ private void GenerateEnums(IndentingTextWriter writer, Type inputType, string cu var enumType = Enum.GetUnderlyingType(type); - _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type.FullName, currentNamespace); + _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type, currentNamespace); if (enumType == typeof(int)) writer.WriteLine($"public enum {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}"); else @@ -623,7 +679,7 @@ string GetFriendlyTypeName(string currentNameSpace, string typeName) if (_typesSymbolTable.ContainsKey(type.FullName)) continue; - _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type.FullName, currentNamespace); + _typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type, currentNamespace); string classBase = ""; if (type.IsSubclassOf(typeof(OneToOneColumn))) classBase = $" : OneToOneColumn<{_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IOneToOneColumn"; @@ -889,7 +945,7 @@ string GetFriendlyTypeName(string currentNameSpace, string typeName) writer.WriteLine("}"); } - private static void GenerateInputFields(IndentingTextWriter writer, + private static void GenerateInputFields(IndentingTextWriter writer, Type inputType, ModuleCatalog catalog, Dictionary typesSymbolTable, string rootNameSpace = "") { var defaults = Activator.CreateInstance(inputType); @@ -936,7 +992,7 @@ string GetFriendlyTypeName(string currentNameSpace, string typeName) sweepableParamAttr.Name = fieldInfo.Name; writer.WriteLine(sweepableParamAttr.ToString()); } - + writer.Write($"public {inputTypeString} {GeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name)} {{ get; set; }}"); var defaultValue = GeneratorUtils.GetValue(catalog, fieldInfo.FieldType, fieldInfo.GetValue(defaults), typesSymbolTable, rootNameSpace); if (defaultValue != null) @@ -1013,16 +1069,16 @@ private void GenerateComponentKind(IndentingTextWriter writer, string kind) private void GenerateComponent(IndentingTextWriter writer, ModuleCatalog.ComponentInfo component, ModuleCatalog catalog) { + GenerateEnums(writer, component.ArgumentType, "Runtime"); + writer.WriteLine(); + GenerateStructs(writer, component.ArgumentType, catalog, "Runtime"); + writer.WriteLine(); writer.WriteLine("/// "); writer.WriteLine($"/// {component.Description}"); writer.WriteLine("/// "); writer.WriteLine($"public sealed class {GeneratorUtils.GetComponentName(component)} : {component.Kind}"); writer.WriteLine("{"); writer.Indent(); - GenerateEnums(writer, component.ArgumentType, ""); - writer.WriteLine(); - GenerateStructs(writer, component.ArgumentType, catalog, ""); - writer.WriteLine(); GenerateInputFields(writer, component.ArgumentType, catalog, _typesSymbolTable, "Microsoft.ML."); writer.WriteLine($"internal override string ComponentName => \"{component.Name}\";"); writer.Outdent(); diff --git a/src/Microsoft.ML/TextLoader.cs b/src/Microsoft.ML/TextLoader.cs index f63a14611b..49be2ee84c 100644 --- a/src/Microsoft.ML/TextLoader.cs +++ b/src/Microsoft.ML/TextLoader.cs @@ -91,7 +91,7 @@ private string TypeToName(Type type) else if (type == typeof(bool)) return "BL"; else - throw new Exception("Type not implemented or supported."); //Add more types. + throw new System.NotSupportedException("Type ${type.FullName} is not implemented or supported."); //Add more types. } public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs new file mode 100644 index 0000000000..79cc2fc137 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -0,0 +1,136 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Models; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Trainers; +using Microsoft.ML.Transforms; +using Xunit; + +namespace Microsoft.ML.Scenarios +{ + public partial class ScenariosTests + { + [Fact] + public void TrainAndPredictIrisModelWithStringLabelTest() + { + string dataPath = GetDataPath("iris.data"); + + var pipeline = new LearningPipeline(); + + pipeline.Add(new TextLoader(dataPath, useHeader: false, separator: ",")); + + pipeline.Add(new Dictionarizer("Label")); // "IrisPlantType" is used as "Label" because of column attribute name on the field. + + pipeline.Add(new ColumnConcatenator(outputColumn: "Features", + "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); + + pipeline.Add(new StochasticDualCoordinateAscentClassifier()); + + PredictionModel model = pipeline.Train(); + + IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel() + { + SepalLength = 3.3f, + SepalWidth = 1.6f, + PetalLength = 0.2f, + PetalWidth = 5.1f, + }); + + Assert.Equal(1, prediction.PredictedLabels[0], 2); + Assert.Equal(0, prediction.PredictedLabels[1], 2); + Assert.Equal(0, prediction.PredictedLabels[2], 2); + + prediction = model.Predict(new IrisDataWithStringLabel() + { + SepalLength = 3.1f, + SepalWidth = 5.5f, + PetalLength = 2.2f, + PetalWidth = 6.4f, + }); + + Assert.Equal(0, prediction.PredictedLabels[0], 2); + Assert.Equal(0, prediction.PredictedLabels[1], 2); + Assert.Equal(1, prediction.PredictedLabels[2], 2); + + prediction = model.Predict(new IrisDataWithStringLabel() + { + SepalLength = 3.1f, + SepalWidth = 2.5f, + PetalLength = 1.2f, + PetalWidth = 4.4f, + }); + + Assert.Equal(.2, prediction.PredictedLabels[0], 1); + Assert.Equal(.8, prediction.PredictedLabels[1], 1); + Assert.Equal(0, prediction.PredictedLabels[2], 2); + + // Note: Testing against the same data set as a simple way to test evaluation. + // This isn't appropriate in real-world scenarios. + string testDataPath = GetDataPath("iris.data"); + var testData = new TextLoader(testDataPath, useHeader: false, separator: ","); + + var evaluator = new ClassificationEvaluator(); + evaluator.OutputTopKAcc = 3; + ClassificationMetrics metrics = evaluator.Evaluate(model, testData); + + Assert.Equal(.98, metrics.AccuracyMacro); + Assert.Equal(.98, metrics.AccuracyMicro, 2); + Assert.Equal(.06, metrics.LogLoss, 2); + Assert.InRange(metrics.LogLossReduction, 94, 96); + Assert.Equal(1, metrics.TopKAccuracy); + + Assert.Equal(3, metrics.PerClassLogLoss.Length); + Assert.Equal(0, metrics.PerClassLogLoss[0], 1); + Assert.Equal(.1, metrics.PerClassLogLoss[1], 1); + Assert.Equal(.1, metrics.PerClassLogLoss[2], 1); + + ConfusionMatrix matrix = metrics.ConfusionMatrix; + Assert.Equal(3, matrix.Order); + Assert.Equal(3, matrix.ClassNames.Count); + Assert.Equal("Iris-setosa", matrix.ClassNames[0]); + Assert.Equal("Iris-versicolor", matrix.ClassNames[1]); + Assert.Equal("Iris-virginica", matrix.ClassNames[2]); + + Assert.Equal(50, matrix[0, 0]); + Assert.Equal(50, matrix["Iris-setosa", "Iris-setosa"]); + Assert.Equal(0, matrix[0, 1]); + Assert.Equal(0, matrix["Iris-setosa", "Iris-versicolor"]); + Assert.Equal(0, matrix[0, 2]); + Assert.Equal(0, matrix["Iris-setosa", "Iris-virginica"]); + + Assert.Equal(0, matrix[1, 0]); + Assert.Equal(0, matrix["Iris-versicolor", "Iris-setosa"]); + Assert.Equal(48, matrix[1, 1]); + Assert.Equal(48, matrix["Iris-versicolor", "Iris-versicolor"]); + Assert.Equal(2, matrix[1, 2]); + Assert.Equal(2, matrix["Iris-versicolor", "Iris-virginica"]); + + Assert.Equal(0, matrix[2, 0]); + Assert.Equal(0, matrix["Iris-virginica", "Iris-setosa"]); + Assert.Equal(1, matrix[2, 1]); + Assert.Equal(1, matrix["Iris-virginica", "Iris-versicolor"]); + Assert.Equal(49, matrix[2, 2]); + Assert.Equal(49, matrix["Iris-virginica", "Iris-virginica"]); + } + + public class IrisDataWithStringLabel + { + [Column("0")] + public float PetalWidth; + + [Column("1")] + public float SepalLength; + + [Column("2")] + public float SepalWidth; + + [Column("3")] + public float PetalLength; + + [Column("4", name: "Label")] + public string IrisPlantType; + } + } +}