From 6e1af049fc241dda48f1602010e428291b859fef Mon Sep 17 00:00:00 2001 From: Hugo Tunius Date: Fri, 17 Jun 2022 14:41:00 +0100 Subject: [PATCH] Improve cancel safety This makes the following methods cancel safe: * `report::receiver::receiver_stream::Receiverstream::read` * `nack::generator::generator_stream::GeneratorStream::read` This was achieved mostly by making locks(`tokio::Mutex`) that were async not be. Since none of these locks are held across `.await` points there's no need for them to be async locks. --- src/mock/mock_time.rs | 14 +-- src/nack/generator/generator_stream.rs | 13 ++- src/nack/generator/mod.rs | 2 +- src/report/mod.rs | 5 +- src/report/receiver/mod.rs | 12 ++- src/report/receiver/receiver_stream.rs | 21 ++--- src/report/receiver/receiver_test.rs | 121 +++++++++---------------- src/report/sender/mod.rs | 5 +- src/report/sender/sender_stream.rs | 2 +- src/report/sender/sender_test.rs | 45 ++++----- 10 files changed, 101 insertions(+), 139 deletions(-) diff --git a/src/mock/mock_time.rs b/src/mock/mock_time.rs index 3ed1e0a..b7bc9f8 100644 --- a/src/mock/mock_time.rs +++ b/src/mock/mock_time.rs @@ -1,5 +1,5 @@ +use std::sync::Mutex; use std::time::{Duration, SystemTime}; -use tokio::sync::Mutex; /// MockTime is a helper to replace SystemTime::now() for testing purposes. pub struct MockTime { @@ -16,20 +16,20 @@ impl Default for MockTime { impl MockTime { /// set_now sets the current time. - pub async fn set_now(&self, now: SystemTime) { - let mut cur_now = self.cur_now.lock().await; + pub fn set_now(&self, now: SystemTime) { + let mut cur_now = self.cur_now.lock().unwrap(); *cur_now = now; } /// now returns the current time. - pub async fn now(&self) -> SystemTime { - let cur_now = self.cur_now.lock().await; + pub fn now(&self) -> SystemTime { + let cur_now = self.cur_now.lock().unwrap(); *cur_now } /// advance advances duration d - pub async fn advance(&mut self, d: Duration) { - let mut cur_now = self.cur_now.lock().await; + pub fn advance(&mut self, d: Duration) { + let mut cur_now = self.cur_now.lock().unwrap(); *cur_now = cur_now.checked_add(d).unwrap_or(*cur_now); } } diff --git a/src/nack/generator/generator_stream.rs b/src/nack/generator/generator_stream.rs index 3ca388e..fdf0b79 100644 --- a/src/nack/generator/generator_stream.rs +++ b/src/nack/generator/generator_stream.rs @@ -1,4 +1,7 @@ +use std::sync::Mutex; + use super::*; + use crate::nack::UINT16SIZE_HALF; use util::Unmarshal; @@ -135,13 +138,13 @@ impl GeneratorStream { } } - pub(super) async fn missing_seq_numbers(&self, skip_last_n: u16) -> Vec { - let internal = self.internal.lock().await; + pub(super) fn missing_seq_numbers(&self, skip_last_n: u16) -> Vec { + let internal = self.internal.lock().unwrap(); internal.missing_seq_numbers(skip_last_n) } - pub(super) async fn add(&self, seq: u16) { - let mut internal = self.internal.lock().await; + pub(super) fn add(&self, seq: u16) { + let mut internal = self.internal.lock().unwrap(); internal.add(seq); } } @@ -155,7 +158,7 @@ impl RTPReader for GeneratorStream { let mut b = &buf[..n]; let pkt = rtp::packet::Packet::unmarshal(&mut b)?; - self.add(pkt.header.sequence_number).await; + self.add(pkt.header.sequence_number); Ok((n, attr)) } diff --git a/src/nack/generator/mod.rs b/src/nack/generator/mod.rs index 85b4f6a..5c0c9db 100644 --- a/src/nack/generator/mod.rs +++ b/src/nack/generator/mod.rs @@ -132,7 +132,7 @@ impl Generator { let mut nacks = vec![]; let streams = internal.streams.lock().await; for (ssrc, stream) in streams.iter() { - let missing = stream.missing_seq_numbers(internal.skip_last_n).await; + let missing = stream.missing_seq_numbers(internal.skip_last_n); if missing.is_empty(){ continue; } diff --git a/src/report/mod.rs b/src/report/mod.rs index 4690639..374709a 100644 --- a/src/report/mod.rs +++ b/src/report/mod.rs @@ -1,7 +1,6 @@ -use rtp::packetizer::FnTimeGen; use std::collections::HashMap; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, SystemTime}; use tokio::sync::{mpsc, Mutex}; use waitgroup::WaitGroup; @@ -13,6 +12,8 @@ use crate::{Interceptor, InterceptorBuilder}; use receiver::{ReceiverReport, ReceiverReportInternal}; use sender::{SenderReport, SenderReportInternal}; +type FnTimeGen = Arc SystemTime + Sync + 'static + Send>; + /// ReceiverBuilder can be used to configure ReceiverReport Interceptor. #[derive(Default)] pub struct ReportBuilder { diff --git a/src/report/receiver/mod.rs b/src/report/receiver/mod.rs index ed20b5f..4425bce 100644 --- a/src/report/receiver/mod.rs +++ b/src/report/receiver/mod.rs @@ -33,7 +33,7 @@ impl RTCPReader for ReceiverReportRtcpReader { let pkts = rtcp::packet::unmarshal(&mut b)?; let now = if let Some(f) = &self.internal.now { - f().await + f() } else { SystemTime::now() }; @@ -48,7 +48,7 @@ impl RTCPReader for ReceiverReportRtcpReader { m.get(&sr.ssrc).cloned() }; if let Some(stream) = stream { - stream.process_sender_report(now, sr).await; + stream.process_sender_report(now, sr); } } } @@ -96,9 +96,11 @@ impl ReceiverReport { loop { tokio::select! { _ = ticker.tick() =>{ + // TODO(cancel safety): This branch isn't cancel safe + let now = if let Some(f) = &internal.now { - f().await - }else{ + f() + } else { SystemTime::now() }; let streams:Vec> = { @@ -106,7 +108,7 @@ impl ReceiverReport { m.values().cloned().collect() }; for stream in streams { - let pkt = stream.generate_report(now).await; + let pkt = stream.generate_report(now); let a = Attributes::new(); if let Err(err) = rtcp_writer.write(&[Box::new(pkt)], &a).await{ diff --git a/src/report/receiver/receiver_stream.rs b/src/report/receiver/receiver_stream.rs index fab003f..06ac9bb 100644 --- a/src/report/receiver/receiver_stream.rs +++ b/src/report/receiver/receiver_stream.rs @@ -2,7 +2,7 @@ use super::*; use crate::{Attributes, RTPReader}; use async_trait::async_trait; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::SystemTime; use util::Unmarshal; @@ -184,25 +184,22 @@ impl ReceiverStream { } } - pub(crate) async fn process_rtp(&self, now: SystemTime, pkt: &rtp::packet::Packet) { - let mut internal = self.internal.lock().await; + pub(crate) fn process_rtp(&self, now: SystemTime, pkt: &rtp::packet::Packet) { + let mut internal = self.internal.lock().unwrap(); internal.process_rtp(now, pkt); } - pub(crate) async fn process_sender_report( + pub(crate) fn process_sender_report( &self, now: SystemTime, sr: &rtcp::sender_report::SenderReport, ) { - let mut internal = self.internal.lock().await; + let mut internal = self.internal.lock().unwrap(); internal.process_sender_report(now, sr); } - pub(crate) async fn generate_report( - &self, - now: SystemTime, - ) -> rtcp::receiver_report::ReceiverReport { - let mut internal = self.internal.lock().await; + pub(crate) fn generate_report(&self, now: SystemTime) -> rtcp::receiver_report::ReceiverReport { + let mut internal = self.internal.lock().unwrap(); internal.generate_report(now) } } @@ -217,11 +214,11 @@ impl RTPReader for ReceiverStream { let mut b = &buf[..n]; let pkt = rtp::packet::Packet::unmarshal(&mut b)?; let now = if let Some(f) = &self.now { - f().await + f() } else { SystemTime::now() }; - self.process_rtp(now, &pkt).await; + self.process_rtp(now, &pkt); Ok((n, attr)) } diff --git a/src/report/receiver/receiver_test.rs b/src/report/receiver/receiver_test.rs index 3d63f2c..35e3bd7 100644 --- a/src/report/receiver/receiver_test.rs +++ b/src/report/receiver/receiver_test.rs @@ -10,13 +10,10 @@ use std::pin::Pin; #[tokio::test] async fn test_receiver_interceptor_before_any_packet() -> Result<()> { let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = ReceiverReport::builder() .with_interval(Duration::from_millis(50)) @@ -65,13 +62,10 @@ async fn test_receiver_interceptor_before_any_packet() -> Result<()> { #[tokio::test] async fn test_receiver_interceptor_after_rtp_packets() -> Result<()> { let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = ReceiverReport::builder() .with_interval(Duration::from_millis(50)) @@ -133,13 +127,10 @@ async fn test_receiver_interceptor_after_rtp_and_rtcp_packets() -> Result<()> { let rtp_time: SystemTime = Utc.ymd(2009, 10, 23).and_hms(0, 0, 0).into(); let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = ReceiverReport::builder() .with_interval(Duration::from_millis(50)) @@ -218,12 +209,10 @@ async fn test_receiver_interceptor_after_rtp_and_rtcp_packets() -> Result<()> { async fn test_receiver_interceptor_overflow() -> Result<()> { let mt = Arc::new(MockTime::default()); let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = ReceiverReport::builder() .with_interval(Duration::from_millis(50)) @@ -290,13 +279,10 @@ async fn test_receiver_interceptor_overflow() -> Result<()> { #[tokio::test] async fn test_receiver_interceptor_overflow_five_pkts() -> Result<()> { let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = ReceiverReport::builder() .with_interval(Duration::from_millis(50)) @@ -395,13 +381,10 @@ async fn test_receiver_interceptor_packet_loss() -> Result<()> { let rtp_time: SystemTime = Utc.ymd(2009, 11, 10).and_hms(23, 0, 0).into(); let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = ReceiverReport::builder() .with_interval(Duration::from_millis(50)) @@ -509,13 +492,10 @@ async fn test_receiver_interceptor_packet_loss() -> Result<()> { #[tokio::test] async fn test_receiver_interceptor_overflow_and_packet_loss() -> Result<()> { let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = ReceiverReport::builder() .with_interval(Duration::from_millis(50)) @@ -582,13 +562,10 @@ async fn test_receiver_interceptor_overflow_and_packet_loss() -> Result<()> { #[tokio::test] async fn test_receiver_interceptor_reordered_packets() -> Result<()> { let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = ReceiverReport::builder() .with_interval(Duration::from_millis(50)) @@ -647,13 +624,10 @@ async fn test_receiver_interceptor_reordered_packets() -> Result<()> { #[tokio::test] async fn test_receiver_interceptor_jitter() -> Result<()> { let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = ReceiverReport::builder() .with_interval(Duration::from_millis(50)) @@ -670,8 +644,7 @@ async fn test_receiver_interceptor_jitter() -> Result<()> { ) .await; - mt.set_now(Utc.ymd(2009, 11, 10).and_hms(23, 0, 0).into()) - .await; + mt.set_now(Utc.ymd(2009, 11, 10).and_hms(23, 0, 0).into()); stream .receive_rtp(rtp::packet::Packet { header: rtp::header::Header { @@ -684,8 +657,7 @@ async fn test_receiver_interceptor_jitter() -> Result<()> { .await; stream.read_rtp().await; - mt.set_now(Utc.ymd(2009, 11, 10).and_hms(23, 0, 1).into()) - .await; + mt.set_now(Utc.ymd(2009, 11, 10).and_hms(23, 0, 1).into()); stream .receive_rtp(rtp::packet::Packet { header: rtp::header::Header { @@ -727,13 +699,10 @@ async fn test_receiver_interceptor_jitter() -> Result<()> { #[tokio::test] async fn test_receiver_interceptor_delay() -> Result<()> { let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = ReceiverReport::builder() .with_interval(Duration::from_millis(50)) @@ -750,8 +719,7 @@ async fn test_receiver_interceptor_delay() -> Result<()> { ) .await; - mt.set_now(Utc.ymd(2009, 11, 10).and_hms(23, 0, 0).into()) - .await; + mt.set_now(Utc.ymd(2009, 11, 10).and_hms(23, 0, 0).into()); stream .receive_rtcp(vec![Box::new(rtcp::sender_report::SenderReport { ssrc: 123456, @@ -764,8 +732,7 @@ async fn test_receiver_interceptor_delay() -> Result<()> { .await; stream.read_rtcp().await; - mt.set_now(Utc.ymd(2009, 11, 10).and_hms(23, 0, 1).into()) - .await; + mt.set_now(Utc.ymd(2009, 11, 10).and_hms(23, 0, 1).into()); let pkts = stream.written_rtcp().await.unwrap(); assert_eq!(pkts.len(), 1); diff --git a/src/report/sender/mod.rs b/src/report/sender/mod.rs index b3277c7..3757c18 100644 --- a/src/report/sender/mod.rs +++ b/src/report/sender/mod.rs @@ -58,9 +58,10 @@ impl SenderReport { loop { tokio::select! { _ = ticker.tick() =>{ + // TODO(cancel safety): This branch isn't cancel safe let now = if let Some(f) = &internal.now { - f().await - }else{ + f() + } else { SystemTime::now() }; let streams:Vec> = { diff --git a/src/report/sender/sender_stream.rs b/src/report/sender/sender_stream.rs index 1daecba..bf1f15a 100644 --- a/src/report/sender/sender_stream.rs +++ b/src/report/sender/sender_stream.rs @@ -94,7 +94,7 @@ impl RTPWriter for SenderStream { /// write a rtp packet async fn write(&self, pkt: &rtp::packet::Packet, a: &Attributes) -> Result { let now = if let Some(f) = &self.now { - f().await + f() } else { SystemTime::now() }; diff --git a/src/report/sender/sender_test.rs b/src/report/sender/sender_test.rs index edf24b7..1f5a96d 100644 --- a/src/report/sender/sender_test.rs +++ b/src/report/sender/sender_test.rs @@ -10,13 +10,10 @@ use std::pin::Pin; #[tokio::test] async fn test_sender_interceptor_before_any_packet() -> Result<()> { let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = SenderReport::builder() .with_interval(Duration::from_millis(50)) @@ -34,7 +31,7 @@ async fn test_sender_interceptor_before_any_packet() -> Result<()> { .await; let dt = Utc.ymd(2009, 10, 23).and_hms(0, 0, 0); - mt.set_now(dt.into()).await; + mt.set_now(dt.into()); let pkts = stream.written_rtcp().await.unwrap(); assert_eq!(pkts.len(), 1); @@ -45,7 +42,7 @@ async fn test_sender_interceptor_before_any_packet() -> Result<()> { assert_eq!( &rtcp::sender_report::SenderReport { ssrc: 123456, - ntp_time: unix2ntp(mt.now().await), + ntp_time: unix2ntp(mt.now()), rtp_time: 4294967295, // pion: 2269117121, packet_count: 0, octet_count: 0, @@ -65,13 +62,10 @@ async fn test_sender_interceptor_before_any_packet() -> Result<()> { #[tokio::test] async fn test_sender_interceptor_after_rtp_packets() -> Result<()> { let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = SenderReport::builder() .with_interval(Duration::from_millis(50)) @@ -101,7 +95,7 @@ async fn test_sender_interceptor_after_rtp_packets() -> Result<()> { } let dt = Utc.ymd(2009, 10, 23).and_hms(0, 0, 0); - mt.set_now(dt.into()).await; + mt.set_now(dt.into()); let pkts = stream.written_rtcp().await.unwrap(); assert_eq!(pkts.len(), 1); @@ -112,7 +106,7 @@ async fn test_sender_interceptor_after_rtp_packets() -> Result<()> { assert_eq!( &rtcp::sender_report::SenderReport { ssrc: 123456, - ntp_time: unix2ntp(mt.now().await), + ntp_time: unix2ntp(mt.now()), rtp_time: 4294967295, // pion: 2269117121, packet_count: 10, octet_count: 20, @@ -132,13 +126,10 @@ async fn test_sender_interceptor_after_rtp_packets() -> Result<()> { #[tokio::test] async fn test_sender_interceptor_after_rtp_packets_overflow() -> Result<()> { let mt = Arc::new(MockTime::default()); - let mt2 = Arc::clone(&mt); - let time_gen = Arc::new( - move || -> Pin + Send + 'static>> { - let mt3 = Arc::clone(&mt2); - Box::pin(async move { mt3.now().await }) - }, - ); + let time_gen = { + let mt = Arc::clone(&mt); + Arc::new(move || mt.now()) + }; let icpr: Arc = SenderReport::builder() .with_interval(Duration::from_millis(50)) @@ -206,7 +197,7 @@ async fn test_sender_interceptor_after_rtp_packets_overflow() -> Result<()> { .await?; let dt = Utc.ymd(2009, 10, 23).and_hms(0, 0, 0); - mt.set_now(dt.into()).await; + mt.set_now(dt.into()); let pkts = stream.written_rtcp().await.unwrap(); assert_eq!(pkts.len(), 1); @@ -217,7 +208,7 @@ async fn test_sender_interceptor_after_rtp_packets_overflow() -> Result<()> { assert_eq!( &rtcp::sender_report::SenderReport { ssrc: 123456, - ntp_time: unix2ntp(mt.now().await), + ntp_time: unix2ntp(mt.now()), rtp_time: 4294967295, // pion: 2269117121, packet_count: 5, octet_count: 10,