Skip to content

Commit

Permalink
Auto merge of #332 - khodzha:oversampling_speex, r=Manishearth
Browse files Browse the repository at this point in the history
2x and 4x oversampling for WaveShaperNode

Issue #205
Finally managed to defeat the noise I mentioned in the issue

I'm not sure whether ditching FrameIterator and operating on raw vecs was a good idea 🤷‍♂️

Also, the spec mentions tail-time, but I'm not sure how to handle it.
  • Loading branch information
bors-servo committed Feb 15, 2020
2 parents bc8c7f7 + a8a1f91 commit f992e21
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 32 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions audio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ servo_media_derive = { path = "../servo-media-derive" }
servo-media-player = { path = "../player" }
servo-media-traits = { path = "../traits" }
smallvec = "0.6.1"
speexdsp-resampler = "0.1.0"

[dependencies.petgraph]
version = "0.4.12"
Expand Down
1 change: 1 addition & 0 deletions audio/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ extern crate euclid;
extern crate num_traits;
extern crate petgraph;
extern crate smallvec;
extern crate speexdsp_resampler;
#[macro_use]
pub mod macros;
extern crate servo_media_traits;
Expand Down
192 changes: 162 additions & 30 deletions audio/wave_shaper_node.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use block::Chunk;
use block::{Chunk, FRAMES_PER_BLOCK_USIZE};
use node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
use speexdsp_resampler::State as SpeexResamplerState;

#[derive(Clone, Debug, PartialEq)]
pub enum OverSampleType {
Expand All @@ -8,6 +9,25 @@ pub enum OverSampleType {
Quadruple,
}

/// How many more render blocks the node keeps processing explicitly after its
/// input has gone silent.
///
/// `process` resets this to `Two` whenever it sees non-silent input. It then
/// counts it down by one for each subsequent silent block that is still run
/// through the shaper. Once it reaches `Zero`, silent input is returned
/// untouched (unless the curve itself turns silence into signal).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TailtimeBlocks {
    /// No tail left to emit.
    Zero,
    /// One more silent block still has to be processed.
    One,
    /// Two more silent blocks still have to be processed.
    Two,
}

/// Quality argument passed to `SpeexResamplerState::new` when the up- and
/// down-samplers are created.
/// NOTE(review): speexdsp qualities presumably range from 0 (fastest, lowest
/// quality) to 10 (best) — confirm against the speexdsp documentation.
const OVERSAMPLING_QUALITY: usize = 0;

impl OverSampleType {
fn value(&self) -> usize {
match self {
OverSampleType::None => 1,
OverSampleType::Double => 2,
OverSampleType::Quadruple => 4,
}
}
}

type WaveShaperCurve = Option<Vec<f32>>;

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -35,10 +55,11 @@ pub(crate) struct WaveShaperNode {
curve_set: bool,
curve: WaveShaperCurve,
#[allow(dead_code)]
// TODO implement tail-time based on the oversample attribute.
// https://github.com/servo/media/issues/205
oversample: OverSampleType,
channel_info: ChannelInfo,
upsampler: Option<SpeexResamplerState>,
downsampler: Option<SpeexResamplerState>,
tailtime_blocks_left: TailtimeBlocks,
}

impl WaveShaperNode {
Expand All @@ -49,15 +70,15 @@ impl WaveShaperNode {
"WaveShaperNode curve must have length of 2 or more"
)
}
if options.oversample != OverSampleType::None {
unimplemented!("No oversampling for WaveShaperNode yet");
}

Self {
curve_set: options.curve.is_some(),
curve: options.curve,
oversample: options.oversample,
channel_info,
upsampler: None,
downsampler: None,
tailtime_blocks_left: TailtimeBlocks::Zero,
}
}

Expand All @@ -79,40 +100,151 @@ impl AudioNodeEngine for WaveShaperNode {
AudioNodeType::WaveShaperNode
}

fn process(&mut self, mut inputs: Chunk, _info: &BlockInfo) -> Chunk {
fn process(&mut self, mut inputs: Chunk, info: &BlockInfo) -> Chunk {
debug_assert!(inputs.len() == 1);

if inputs.blocks[0].is_silence() {
if self.curve.is_none() {
return inputs;
}

if let Some(curve) = &self.curve {
let mut iter = inputs.blocks[0].iter();

while let Some(mut frame) = iter.next() {
frame.mutate_with(|sample, _| {
let len = curve.len();
let curve_index: f32 = ((len - 1) as f32) * (*sample + 1.) / 2.;

if curve_index <= 0. {
*sample = curve[0];
} else if curve_index >= len as f32 {
*sample = curve[len - 1];
} else {
let index_lo = curve_index as usize;
let index_hi = index_lo + 1;
let interp_factor: f32 = curve_index - index_lo as f32;
*sample = (1. - interp_factor) * curve[index_lo]
+ interp_factor * curve[index_hi];
}
});
let curve = &self.curve.as_ref().expect("Just checked for is_none()");

if inputs.blocks[0].is_silence() {
if WaveShaperNode::silence_produces_nonsilent_output(curve) {
inputs.blocks[0].explicit_silence();
self.tailtime_blocks_left = TailtimeBlocks::Two;
} else if self.tailtime_blocks_left != TailtimeBlocks::Zero {
inputs.blocks[0].explicit_silence();

self.tailtime_blocks_left = match self.tailtime_blocks_left {
TailtimeBlocks::Zero => TailtimeBlocks::Zero,
TailtimeBlocks::One => TailtimeBlocks::Zero,
TailtimeBlocks::Two => TailtimeBlocks::One,
}
} else {
return inputs;
}
} else {
self.tailtime_blocks_left = TailtimeBlocks::Two;
}

inputs
let block = &mut inputs.blocks[0];
let channels = block.chan_count();

if self.oversample != OverSampleType::None {
let rate: usize = info.sample_rate as usize;
let sampling_factor = self.oversample.value();

if self.upsampler.is_none() {
self.upsampler = Some(
SpeexResamplerState::new(
channels as usize,
rate,
rate * sampling_factor,
OVERSAMPLING_QUALITY,
)
.expect("Couldnt create upsampler"),
);
};

if self.downsampler.is_none() {
self.downsampler = Some(
SpeexResamplerState::new(
channels as usize,
rate * sampling_factor,
rate,
OVERSAMPLING_QUALITY,
)
.expect("Couldnt create downsampler"),
);
};

let mut upsampler = self.upsampler.as_mut().unwrap();
let mut downsampler = self.downsampler.as_mut().unwrap();

let mut oversampled_buffer: Vec<f32> =
vec![0.; FRAMES_PER_BLOCK_USIZE * sampling_factor];

for chan in 0..channels {
let out_len = WaveShaperNode::resample(
&mut upsampler,
chan,
block.data_chan(chan),
&mut oversampled_buffer,
);

debug_assert!(
out_len == 128 * sampling_factor,
"Expected {} samples in output after upsampling, got: {}",
128 * sampling_factor,
out_len
);

WaveShaperNode::apply_curve(&mut oversampled_buffer, &curve);

let out_len = WaveShaperNode::resample(
&mut downsampler,
chan,
&oversampled_buffer,
&mut block.data_chan_mut(chan),
);

debug_assert!(
out_len == 128,
"Expected 128 samples in output after downsampling, got {}",
out_len
);
}
} else {
inputs
WaveShaperNode::apply_curve(block.data_mut(), &curve);
}

inputs
}

make_message_handler!(WaveShaperNode: handle_waveshaper_message);
}

impl WaveShaperNode {
fn silence_produces_nonsilent_output(curve: &Vec<f32>) -> bool {
let len = curve.len();
let len_halved = ((len - 1) as f32) / 2.;
let curve_index: f32 = len_halved;
let index_lo = curve_index as usize;
let index_hi = index_lo + 1;
let interp_factor: f32 = curve_index - index_lo as f32;
let shaped_val = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
shaped_val == 0.0
}

fn apply_curve(buf: &mut [f32], curve: &Vec<f32>) {
let len = curve.len();
let len_halved = ((len - 1) as f32) / 2.;
buf.iter_mut().for_each(|sample| {
let curve_index: f32 = len_halved * (*sample + 1.);

if curve_index <= 0. {
*sample = curve[0];
} else if curve_index >= (len - 1) as f32 {
*sample = curve[len - 1];
} else {
let index_lo = curve_index as usize;
let index_hi = index_lo + 1;
let interp_factor: f32 = curve_index - index_lo as f32;
*sample = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
}
});
}

fn resample(
st: &mut SpeexResamplerState,
chan: u8,
input: &[f32],
output: &mut [f32],
) -> usize {
let (_in_len, out_len) = st
.process_float(chan as usize, input, output)
.expect("Resampling failed");
out_len
}
}
40 changes: 38 additions & 2 deletions examples/wave_shaper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ fn run_example(servo_media: Arc<ServoMedia>) {

{
let context = context.lock().unwrap();
let curve = vec![1., 0., 0., 0.75, 0.5];

let dest = context.dest_node();
let osc = context.create_node(
Expand All @@ -24,11 +25,25 @@ fn run_example(servo_media: Arc<ServoMedia>) {
);
let wsh = context.create_node(
AudioNodeInit::WaveShaperNode(WaveShaperNodeOptions {
curve: Some(vec![0., 0., 0., 4., 6.]),
curve: Some(curve.clone()),
oversample: OverSampleType::None,
}),
Default::default(),
);
let wshx2 = context.create_node(
AudioNodeInit::WaveShaperNode(WaveShaperNodeOptions {
curve: Some(curve.clone()),
oversample: OverSampleType::Double,
}),
Default::default(),
);
let wshx4 = context.create_node(
AudioNodeInit::WaveShaperNode(WaveShaperNodeOptions {
curve: Some(curve.clone()),
oversample: OverSampleType::Quadruple,
}),
Default::default(),
);

context.connect_ports(osc.output(0), dest.input(0));
let _ = context.resume();
Expand All @@ -40,8 +55,29 @@ fn run_example(servo_media: Arc<ServoMedia>) {
println!("raw oscillator");
thread::sleep(time::Duration::from_millis(2000));

println!("oscillator through waveshaper");
println!("oscillator through waveshaper with no oversampling");
context.disconnect_output(osc.output(0));
context.connect_ports(osc.output(0), wsh.input(0));
context.connect_ports(wsh.output(0), dest.input(0));
thread::sleep(time::Duration::from_millis(2000));

println!("oscillator through waveshaper with 2x oversampling");
context.disconnect_output(osc.output(0));
context.disconnect_output(wsh.output(0));
context.connect_ports(osc.output(0), wshx2.input(0));
context.connect_ports(wshx2.output(0), dest.input(0));
thread::sleep(time::Duration::from_millis(2000));

println!("oscillator through waveshaper with 4x oversampling");
context.disconnect_output(osc.output(0));
context.disconnect_output(wshx2.output(0));
context.connect_ports(osc.output(0), wshx4.input(0));
context.connect_ports(wshx4.output(0), dest.input(0));
thread::sleep(time::Duration::from_millis(2000));

println!("oscillator through waveshaper with no oversampling");
context.disconnect_output(osc.output(0));
context.disconnect_output(wshx4.output(0));
context.connect_ports(osc.output(0), wsh.input(0));
context.connect_ports(wsh.output(0), dest.input(0));
thread::sleep(time::Duration::from_millis(2000));
Expand Down

0 comments on commit f992e21

Please sign in to comment.