Skip to content

Commit

Permalink
flake: fix rust src for neovim, format examples
Browse files Browse the repository at this point in the history
  • Loading branch information
jhvst committed Aug 6, 2023
1 parent 4842fca commit 5ddd210
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 61 deletions.
35 changes: 21 additions & 14 deletions examples/reduce.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use rivi_loader::{DebugOption, PushConstant, Task, GroupCount, Vulkan};
use rivi_loader::{DebugOption, GroupCount, PushConstant, Task, Vulkan};

fn main() {

let vk = Vulkan::new(DebugOption::None).unwrap();
let gpus = vk.compute.as_ref().unwrap();

loop {

let binary = &include_bytes!("./reduce/reduce.spv")[..];
let module = rspirv::dr::load_bytes(binary).unwrap();

Expand All @@ -20,26 +18,35 @@ fn main() {

let vec4 = 4;
let mut tasks = vec![Task {
input: vec![
vec![1.0f32; gpu.subgroup_size * gpu.subgroup_size * vec4],
],
input: vec![vec![1.0f32; gpu.subgroup_size * gpu.subgroup_size * vec4]],
output: vec![0.0f32; 4096],
push_constants: vec![
PushConstant { offset: 0, constants: vec![2] },
],
push_constants: vec![PushConstant {
offset: 0,
constants: vec![2],
}],
queue: *queue,
group_count: GroupCount { ..Default::default() },
group_count: GroupCount { x: 1, y: 1, z: 1 },
}];

let run_timer = std::time::Instant::now();
gpu.scheduled(&shader, queue_family, &mut tasks).unwrap();
let end_timer = run_timer.elapsed().as_micros();

let task = tasks.first().unwrap();
println!("Queue Family {}, Queue {:?}: {:?} in {}qs", queue_family.phy_index, task.queue, task.output[0], end_timer);
assert_eq!(task.output[0], (gpu.subgroup_size * gpu.subgroup_size * vec4) as f32);
let errors = task.output.iter().enumerate().filter(|(i, v)| (*v).ne(&0f32) && i.ne(&0)).collect::<Vec<_>>();
println!(
"Queue Family {}, Queue {:?}: {:?} in {}qs",
queue_family.phy_index, task.queue, task.output[0], end_timer
);
assert_eq!(
task.output[0],
(gpu.subgroup_size * gpu.subgroup_size * vec4) as f32
);
let errors = task
.output
.iter()
.enumerate()
.filter(|(i, v)| (*v).ne(&0f32) && i.ne(&0))
.collect::<Vec<_>>();
println!("Errors: {:?}", errors)

}
}
106 changes: 59 additions & 47 deletions examples/rf.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::error::Error;

use rivi_loader::{DebugOption, GroupCount, Task, Vulkan};
use rayon::prelude::*;
use rivi_loader::{DebugOption, GroupCount, Task, Vulkan};

/// `rf.rs` runs Python Scikit derived random forest prediction algorithm.
/// The implementation of this algorithm was derived from Python/Cython to APL, and
Expand All @@ -24,7 +24,6 @@ fn main() {
}

fn batched(vk: &rivi_loader::Vulkan, shader: &rspirv::dr::Module) -> u128 {

let gpu = vk.compute.as_ref().unwrap().first().unwrap();
let threads = gpu.fences.as_ref().unwrap().len();

Expand All @@ -34,45 +33,54 @@ fn batched(vk: &rivi_loader::Vulkan, shader: &rspirv::dr::Module) -> u128 {
// create upper bound for iterations
let bound = (150.0 / threads as f32).ceil() as i32;

(0..bound).map(|_| {

let specializations = Vec::new();
let shader = rivi_loader::load_shader(gpu, shader.clone(), specializations).unwrap();

let time = gpu.fences.as_ref().unwrap().par_iter().map(|fence| {

let mut tasks = fence.queues.iter().map(|queue| {
Task {
input: dataset[bound as usize].clone(),
output: vec![0.0f32; 1_146_024],
push_constants: vec![],
queue: *queue,
group_count: GroupCount {
x: 1024,
..Default::default()
},
}
})
.collect::<Vec<_>>();

let run_timer = std::time::Instant::now();
gpu.scheduled(&shader, fence, &mut tasks).unwrap();
let end_timer = run_timer.elapsed().as_millis();

tasks.into_iter().for_each(|t| assert_eq!(t.output.into_iter().map(|f| f as f64).sum::<f64>(), 490058.0_f64));

end_timer
}).collect::<Vec<_>>();

time.iter().sum::<u128>() / gpu.fences.as_ref().unwrap().len() as u128

}).sum()
(0..bound)
.map(|_| {
let specializations = Vec::new();
let shader = rivi_loader::load_shader(gpu, shader.clone(), specializations).unwrap();

let time = gpu
.fences
.as_ref()
.unwrap()
.par_iter()
.map(|fence| {
let mut tasks = fence
.queues
.iter()
.map(|queue| Task {
input: dataset[bound as usize].clone(),
output: vec![0.0f32; 1_146_024],
push_constants: vec![],
queue: *queue,
group_count: GroupCount {
x: 1024,
..Default::default()
},
})
.collect::<Vec<_>>();

let run_timer = std::time::Instant::now();
gpu.scheduled(&shader, fence, &mut tasks).unwrap();
let end_timer = run_timer.elapsed().as_millis();

tasks.into_iter().for_each(|t| {
assert_eq!(
t.output.into_iter().map(|f| f as f64).sum::<f64>(),
490058.0_f64
)
});

end_timer
})
.collect::<Vec<_>>();

time.iter().sum::<u128>() / gpu.fences.as_ref().unwrap().len() as u128
})
.sum()
}

fn csv(f: &str, v: &mut Vec<f32>) -> Result<(), Box<dyn Error>> {
let mut reader = csv::ReaderBuilder::new()
.has_headers(false)
.from_path(f)?;
let mut reader = csv::ReaderBuilder::new().has_headers(false).from_path(f)?;
for record in reader.records() {
let record = record?;
for field in record.iter() {
Expand All @@ -84,7 +92,6 @@ fn csv(f: &str, v: &mut Vec<f32>) -> Result<(), Box<dyn Error>> {
}

fn load_input(chunks: usize) -> Vec<Vec<Vec<f32>>> {

let mut feature: Vec<f32> = Vec::new();
if let Err(err) = csv("examples/rf/dataset/feature.csv", &mut feature) {
panic!("error running example: {}", err);
Expand Down Expand Up @@ -115,12 +122,17 @@ fn load_input(chunks: usize) -> Vec<Vec<Vec<f32>>> {
panic!("error running example: {}", err);
}

(0..chunks).into_iter().map(|_| vec![
left.clone(),
right.clone(),
th.clone(),
feature.clone(),
values.clone(),
x.clone()
]).collect()
(0..chunks)
.into_iter()
.map(|_| {
vec![
left.clone(),
right.clone(),
th.clone(),
feature.clone(),
values.clone(),
x.clone(),
]
})
.collect()
}
1 change: 1 addition & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
in
{
devShells.default = pkgs.mkShell {
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
inherit nativeBuildInputs;
packages = with pkgs; [
rustc
Expand Down

0 comments on commit 5ddd210

Please sign in to comment.