Skip to content

Commit 9d443ab

Browse files
Merge pull request #15 from lattice-based-cryptography/structure_KEM_class
refactor KEM struct
2 parents 8cd25ba + 3f5a65e commit 9d443ab

File tree

6 files changed

+114
-7
lines changed

6 files changed

+114
-7
lines changed

.github/workflows/basic.yml

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: basic
2+
3+
on:
4+
push:
5+
branches: [ "main" ]
6+
pull_request:
7+
branches: [ "main" ]
8+
9+
env:
10+
CARGO_TERM_COLOR: always
11+
12+
jobs:
13+
build:
14+
runs-on: ubuntu-latest
15+
16+
steps:
17+
- uses: actions/checkout@v4
18+
19+
# Install the Rust toolchain
20+
- name: Install Rust toolchain
21+
uses: dtolnay/rust-toolchain@stable
22+
23+
# Cache Cargo dependencies to speed up builds
24+
- name: Cache Cargo dependencies
25+
uses: actions/cache@v3
26+
with:
27+
path: ~/.cargo/registry
28+
key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.lock') }}
29+
restore-keys: |
30+
cargo-${{ runner.os }}-
31+
32+
# Build the project
33+
- name: Build
34+
run: cargo build --verbose
35+
36+
# Run the tests
37+
- name: Run Tests
38+
run: cargo test --verbose

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.DS_Store
2+
13
# Generated by Cargo
24
# will have compiled files and executables
35
debug/

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
pub mod kem;
1+
pub mod ml_kem;
22
pub mod utils;

src/main.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
use ml_kem::kem::KEM;
1+
use ml_kem::ml_kem::MLKEM;
22
use ml_kem::utils::Parameters;
33
mod tests;
44

55
fn main() {
66
let params = Parameters::default();
77

88
// Generate key pair
9-
let (public_key, secret_key) = KEM::keygen(&params);
9+
let mlkem = MLKEM::new(params);
10+
let (public_key, secret_key) = mlkem.keygen();
1011

1112
// Print keys for verification
1213
println!("Public Key: {:?}", public_key);

src/ml_kem.rs

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
use module_lwe::utils::{gen_uniform_matrix,mul_mat_vec_simple,gen_small_vector,add_vec};
2+
use module_lwe::encrypt::encrypt;
3+
use module_lwe::decrypt::decrypt;
4+
use ring_lwe::utils::gen_binary_poly;
5+
use crate::utils::{Parameters, hash};
6+
use polynomial_ring::Polynomial;
7+
8+
pub struct MLKEM {
9+
params: Parameters,
10+
}
11+
12+
impl MLKEM {
13+
// Constructor to initialize MLKEM with parameters
14+
pub fn new(params: Parameters) -> Self {
15+
MLKEM { params } // Corrected: properly initializes and returns the struct
16+
}
17+
18+
pub fn keygen(&self) -> ((Vec<Vec<Polynomial<i64>>>, Vec<Polynomial<i64>>), Vec<Polynomial<i64>>) {
19+
let a = gen_uniform_matrix(self.params.n, self.params.k, self.params.q, None);
20+
21+
let s = gen_small_vector(self.params.n, self.params.k, None);
22+
let e = gen_small_vector(self.params.n, self.params.k, None);
23+
24+
let b = add_vec(
25+
&mul_mat_vec_simple(&a, &s, self.params.q, &self.params.f, self.params.omega),
26+
&e,
27+
self.params.q,
28+
&self.params.f
29+
);
30+
31+
((a, b), s)
32+
}
33+
34+
pub fn encapsulate(&self, pk: (Vec<Vec<Polynomial<i64>>>, Vec<Polynomial<i64>>)) -> (String, (Vec<Polynomial<i64>>, Polynomial<i64>)) {
35+
let params_mlwe = module_lwe::utils::Parameters {
36+
n: self.params.n,
37+
q: self.params.q,
38+
k: self.params.k,
39+
omega: self.params.omega,
40+
f: self.params.f.clone()
41+
};
42+
43+
let mut m = gen_binary_poly(self.params.n, None).coeffs().to_vec();
44+
m.resize(self.params.n, 0);
45+
46+
let ct = encrypt(&pk.0, &pk.1, &m, &params_mlwe, None);
47+
let k = hash(m);
48+
(k, ct)
49+
}
50+
51+
pub fn decapsulate(&self, sk: Vec<Polynomial<i64>>, ct: (Vec<Polynomial<i64>>, Polynomial<i64>)) -> String {
52+
let params_mlwe = module_lwe::utils::Parameters {
53+
n: self.params.n,
54+
q: self.params.q,
55+
k: self.params.k,
56+
omega: self.params.omega,
57+
f: self.params.f.clone()
58+
};
59+
60+
let mut m = decrypt(&sk, &ct.0, &ct.1, &params_mlwe);
61+
m.resize(self.params.n, 0);
62+
63+
hash(m)
64+
}
65+
}

src/tests.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
#[cfg(test)] // This makes the following module compile only during tests
22
mod tests {
33
use ml_kem::utils::Parameters;
4-
use ml_kem::kem::KEM;
4+
use ml_kem::ml_kem::MLKEM;
55

66
// Test for basic keygen/encapsulate/decapsulate
77
#[test]
88
pub fn test_basic() {
99
let params = Parameters::default(); // Adjust this if needed
10-
let (pk, sk) = KEM::keygen(&params);
11-
let (k, ct) = KEM::encapsulate(pk, &params);
12-
let k_recovered = KEM::decapsulate(sk, ct, &params);
10+
let mlkem = MLKEM::new(params);
11+
let (pk, sk) = mlkem.keygen();
12+
let (k, ct) = mlkem.encapsulate(pk);
13+
let k_recovered = mlkem.decapsulate(sk, ct);
1314
assert_eq!(k, k_recovered, "test failed: {} != {}", k, k_recovered);
1415
}
1516
}

0 commit comments

Comments
 (0)