Skip to content

Commit 5b2712f

Browse files
committed
Add a 1-dimensional Gaussian mixture (EM) example generated with ChatGPT 3.5
1 parent 4127d68 commit 5b2712f

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

gaussian_mixture/Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[package]
name = "gaussian_mixture"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
# Pin to compatible minor versions: wildcard ("*") requirements make builds
# non-reproducible and are rejected when publishing to crates.io.
rand = "0.8"
nalgebra = "0.32"

gaussian_mixture/src/main.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
use rand::Rng;
2+
use nalgebra::{DVector, DMatrix};
3+
4+
fn main() {
5+
// Generate synthetic data
6+
let data = generate_data();
7+
8+
// Initialize parameters randomly
9+
let num_components = 2;
10+
let (mut weights, mut means, mut covariances) = initialize_parameters(&data, num_components);
11+
12+
// Run the EM algorithm
13+
let num_iterations = 100;
14+
for _ in 0..num_iterations {
15+
// Expectation step
16+
let responsibilities = expectation_step(&data, &weights, &means, &covariances);
17+
18+
// Maximization step
19+
maximization_step(&data, &responsibilities, &mut weights, &mut means, &mut covariances);
20+
}
21+
22+
// Print the final parameters
23+
println!("Final weights: {:?}", weights);
24+
println!("Final means: {:?}", means);
25+
println!("Final covariances: {:?}", covariances);
26+
}
27+
28+
fn generate_data() -> DVector<f64> {
29+
let mut rng = rand::thread_rng();
30+
let num_samples = 1000;
31+
let data: DVector<f64> = (0..num_samples)
32+
.map(|_| rng.gen_range(-5.0..5.0))
33+
.collect();
34+
data
35+
}
36+
37+
/// Produce random starting parameters for `num_components` mixture
/// components: uniform weights, means drawn between the data minimum and
/// maximum, and variances drawn from [0.1, 1.0).
///
/// NOTE(review): `gen_range` panics on an empty range, so this panics if
/// all data points are identical — acceptable for this synthetic example.
fn initialize_parameters(
    data: &DVector<f64>,
    num_components: usize,
) -> (DVector<f64>, DVector<f64>, DVector<f64>) {
    let mut rng = rand::thread_rng();

    // Equal prior weight for every component.
    let weights = DVector::from_element(num_components, 1.0 / num_components as f64);

    // Means: uniform draws inside the observed data range.
    let (lo, hi) = (data.min(), data.max());
    let means = DVector::from_fn(num_components, |_, _| rng.gen_range(lo..hi));

    // Variances: uniform draws in [0.1, 1.0), keeping them strictly positive.
    let covariances = DVector::from_fn(num_components, |_, _| rng.gen_range(0.1..1.0));

    (weights, means, covariances)
}
58+
59+
fn expectation_step(data: &DVector<f64>, weights: &DVector<f64>, means: &DVector<f64>, covariances: &DVector<f64>) -> DMatrix<f64> {
60+
let num_samples = data.len();
61+
let num_components = weights.len();
62+
63+
let mut responsibilities = DMatrix::zeros(num_samples, num_components);
64+
65+
for i in 0..num_samples {
66+
for j in 0..num_components {
67+
responsibilities[(i, j)] = weights[j]
68+
* gaussian_pdf(data[i], means[j], covariances[j])
69+
/ (0..num_components).map(|k| weights[k] * gaussian_pdf(data[i], means[k], covariances[k])).sum::<f64>();
70+
}
71+
}
72+
73+
responsibilities
74+
}
75+
76+
/// M-step: re-estimate weights, means, and variances in place from the
/// responsibilities produced by the E-step.
///
/// The three per-component updates are fused into one loop: the mean
/// update reads only `data`/`responsibilities`, and the variance update
/// reads the freshly updated `means[j]`, exactly as in separate passes.
fn maximization_step(
    data: &DVector<f64>,
    responsibilities: &DMatrix<f64>,
    weights: &mut DVector<f64>,
    means: &mut DVector<f64>,
    covariances: &mut DVector<f64>,
) {
    let num_samples = data.len();
    let num_components = weights.len();

    for j in 0..num_components {
        // Effective number of samples assigned to component j; computed
        // once instead of once per parameter update.
        let resp_sum = responsibilities.column(j).sum();

        // Mixing weight: fraction of the total responsibility mass.
        weights[j] = resp_sum / num_samples as f64;

        // Mean: responsibility-weighted average of the data.
        means[j] = responsibilities.column(j).dot(data) / resp_sum;

        // Variance: responsibility-weighted mean squared deviation.
        // FIX: `data - means[j]` (matrix minus scalar via `-`) is not an
        // operation nalgebra implements; build the squared deviations with
        // an explicit element-wise map instead.
        let squared_dev = data.map(|x| (x - means[j]).powi(2));
        let variance = responsibilities.column(j).dot(&squared_dev) / resp_sum;

        // Floor the variance so gaussian_pdf never divides by ~zero.
        covariances[j] = variance.max(1e-6);
    }
}
96+
97+
/// Density of a univariate normal distribution with the given `mean` and
/// variance (`covariance` is the 1-D covariance, i.e. sigma^2), evaluated
/// at `x`:
///
///     N(x) = exp(-(x - mean)^2 / (2 var)) / sqrt(2 pi var)
fn gaussian_pdf(x: f64, mean: f64, covariance: f64) -> f64 {
    let diff = x - mean;
    let exponent = -(diff * diff) / (2.0 * covariance);
    let coefficient = 1.0 / (2.0 * std::f64::consts::PI * covariance).sqrt();
    coefficient * exponent.exp()
}

0 commit comments

Comments
 (0)