Skip to content

Commit

Permalink
Multinomial logistic regression (#159)
Browse files Browse the repository at this point in the history
* Generalized argmin param across dimensions

* Wrote multinomial loss and gradient

* Finished multi fitted model

* Implement dot and norm on ArgminParams

* Add test for multinomial loss and grad

* Add multi logistic regression tests

* Add docs
  • Loading branch information
YuhanLiin committed Aug 31, 2021
1 parent 992938e commit e06f0be
Show file tree
Hide file tree
Showing 4 changed files with 704 additions and 248 deletions.
1 change: 1 addition & 0 deletions algorithms/linfa-logistic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ categories = ["algorithms", "mathematics", "science"]
[dependencies]
ndarray = { version = "0.15", features = ["approx", "blas"] }
ndarray-linalg = "0.14"
ndarray-stats = "0.5.0"
num-traits = "0.2"
argmin = { version = "0.4.6", features = ["ndarrayl"] }
serde = "1.0"
Expand Down
45 changes: 27 additions & 18 deletions algorithms/linfa-logistic/src/argmin_param.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! This module defines newtypes for ndarray's Array1.
//! This module defines newtypes for ndarray's Array.
//!
//! This is necessary to be able to abstract over floats (f32 and f64) so that
//! the logistic regression code can be abstract in the float type it works
Expand All @@ -8,51 +8,60 @@

use crate::float::Float;
use argmin::prelude::*;
use ndarray::Array1;
use ndarray::{Array, ArrayBase, Data, Dimension, Zip};
use serde::{Deserialize, Serialize};

pub fn elem_dot<F: linfa::Float, A1: Data<Elem = F>, A2: Data<Elem = F>, D: Dimension>(
a: &ArrayBase<A1, D>,
b: &ArrayBase<A2, D>,
) -> F {
Zip::from(a)
.and(b)
.fold(F::zero(), |acc, &a, &b| acc + a * b)
}

#[derive(Serialize, Clone, Deserialize, Debug, Default)]
pub struct ArgminParam<F>(pub Array1<F>);
pub struct ArgminParam<F, D: Dimension>(pub Array<F, D>);

impl<F> ArgminParam<F> {
impl<F, D: Dimension> ArgminParam<F, D> {
#[inline]
pub fn as_array(&self) -> &Array1<F> {
pub fn as_array(&self) -> &Array<F, D> {
&self.0
}
}

impl<F: Float> ArgminSub<ArgminParam<F>, ArgminParam<F>> for ArgminParam<F> {
fn sub(&self, other: &ArgminParam<F>) -> ArgminParam<F> {
impl<F: Float, D: Dimension> ArgminSub<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
fn sub(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
ArgminParam(&self.0 - &other.0)
}
}

impl<F: Float> ArgminAdd<ArgminParam<F>, ArgminParam<F>> for ArgminParam<F> {
fn add(&self, other: &ArgminParam<F>) -> ArgminParam<F> {
impl<F: Float, D: Dimension> ArgminAdd<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
fn add(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
ArgminParam(&self.0 + &other.0)
}
}

impl<F: Float> ArgminDot<ArgminParam<F>, F> for ArgminParam<F> {
fn dot(&self, other: &ArgminParam<F>) -> F {
self.0.dot(&other.0)
impl<F: Float, D: Dimension> ArgminDot<ArgminParam<F, D>, F> for ArgminParam<F, D> {
fn dot(&self, other: &ArgminParam<F, D>) -> F {
elem_dot(&self.0, &other.0)
}
}

impl<F: Float> ArgminNorm<F> for ArgminParam<F> {
impl<F: Float, D: Dimension> ArgminNorm<F> for ArgminParam<F, D> {
fn norm(&self) -> F {
self.0.dot(&self.0)
num_traits::Float::sqrt(elem_dot(&self.0, &self.0))
}
}

impl<F: Float> ArgminMul<F, ArgminParam<F>> for ArgminParam<F> {
fn mul(&self, other: &F) -> ArgminParam<F> {
impl<F: Float, D: Dimension> ArgminMul<F, ArgminParam<F, D>> for ArgminParam<F, D> {
fn mul(&self, other: &F) -> ArgminParam<F, D> {
ArgminParam(&self.0 * *other)
}
}

impl<F: Float> ArgminMul<ArgminParam<F>, ArgminParam<F>> for ArgminParam<F> {
fn mul(&self, other: &ArgminParam<F>) -> ArgminParam<F> {
impl<F: Float, D: Dimension> ArgminMul<ArgminParam<F, D>, ArgminParam<F, D>> for ArgminParam<F, D> {
fn mul(&self, other: &ArgminParam<F, D>) -> ArgminParam<F, D> {
ArgminParam(&self.0 * &other.0)
}
}
13 changes: 7 additions & 6 deletions algorithms/linfa-logistic/src/float.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::argmin_param::ArgminParam;
use argmin::prelude::{ArgminFloat, ArgminMul};
use ndarray::NdFloat;
use ndarray::{Dimension, Ix1, Ix2, NdFloat};
use ndarray_linalg::Lapack;
use num_traits::FromPrimitive;

Expand All @@ -13,21 +13,22 @@ pub trait Float:
+ Default
+ Clone
+ FromPrimitive
+ ArgminMul<ArgminParam<Self>, ArgminParam<Self>>
+ ArgminMul<ArgminParam<Self, Ix1>, ArgminParam<Self, Ix1>>
+ ArgminMul<ArgminParam<Self, Ix2>, ArgminParam<Self, Ix2>>
+ linfa::Float
{
const POSITIVE_LABEL: Self;
const NEGATIVE_LABEL: Self;
}

impl ArgminMul<ArgminParam<Self>, ArgminParam<Self>> for f64 {
fn mul(&self, other: &ArgminParam<Self>) -> ArgminParam<Self> {
impl<D: Dimension> ArgminMul<ArgminParam<Self, D>, ArgminParam<Self, D>> for f64 {
fn mul(&self, other: &ArgminParam<Self, D>) -> ArgminParam<Self, D> {
ArgminParam(&other.0 * *self)
}
}

impl ArgminMul<ArgminParam<Self>, ArgminParam<Self>> for f32 {
fn mul(&self, other: &ArgminParam<Self>) -> ArgminParam<Self> {
impl<D: Dimension> ArgminMul<ArgminParam<Self, D>, ArgminParam<Self, D>> for f32 {
fn mul(&self, other: &ArgminParam<Self, D>) -> ArgminParam<Self, D> {
ArgminParam(&other.0 * *self)
}
}
Expand Down
Loading

0 comments on commit e06f0be

Please sign in to comment.