Skip to content

Commit

Permalink
Reworking download_iterator to add concurrent download limit (#190)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrossard committed Jan 28, 2023
1 parent a6291b7 commit 1c369fb
Showing 1 changed file with 128 additions and 68 deletions.
196 changes: 128 additions & 68 deletions rust/cmsis-pack/src/update/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use futures::prelude::*;
use futures::stream::futures_unordered::FuturesUnordered;
use reqwest::{redirect, Url};
use reqwest::{Client, ClientBuilder, Response};
use tokio::task::JoinHandle;
use tokio::time::{sleep, Duration};

use crate::pack_index::{PdscRef, Pidx, Vidx};
use crate::pdsc::Package;
Expand All @@ -16,6 +18,9 @@ use bytes::Bytes;
use futures::StreamExt;
use std::collections::HashMap;

const CONCURRENCY : usize = 32;
const HOST_LIMIT : usize = 6;

fn parse_vidx(body: Bytes) -> Result<Vidx, Error> {
let string = String::from_utf8_lossy(body.as_ref());
Vidx::from_string(string.borrow())
Expand Down Expand Up @@ -104,6 +109,42 @@ impl<'a> IntoDownload for &'a Package {
}
}


async fn save_response(response: Response, dest: PathBuf) -> Result<(usize, PathBuf), Error> {
let temp = dest.with_extension("part");
let file = OpenOptions::new().write(true).create(true).open(&temp);

let mut file = match file {
Err(err) => return Err(anyhow!(err.to_string())),
Ok(f) => f,
};

let mut fsize: usize = 0;
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
fsize += bytes.len();

if let Err(err) = file.write_all(bytes.as_ref()) {
let _ = std::fs::remove_file(temp);
return Err(anyhow!(err.to_string()));
}
}
Err(err) => {
let _ = std::fs::remove_file(temp);
return Err(anyhow!(err.to_string()));
}
}
}
if let Err(err) = rename(&temp, &dest) {
let _ = std::fs::remove_file(temp);
return Err(anyhow!(err.to_string()));
}
Ok((fsize, dest))
}


pub trait DownloadProgress: Send {
fn size(&self, files: usize);
fn progress(&self, bytes: usize);
Expand Down Expand Up @@ -145,91 +186,110 @@ where
})
}

async fn save_response(&'a self, response: Response, dest: PathBuf) -> Result<PathBuf, Error> {
let temp = dest.with_extension("part");
let file = OpenOptions::new().write(true).create(true).open(&temp);

let mut file = match file {
Err(err) => return Err(anyhow!(err.to_string())),
Ok(f) => f,
};

let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
self.prog.progress(bytes.len());

if let Err(err) = file.write_all(bytes.as_ref()) {
std::fs::remove_file(temp);
return Err(anyhow!(err.to_string()));
}
}
Err(err) => {
std::fs::remove_file(temp);
return Err(anyhow!(err.to_string()));
}
}
}
if let Err(err) = rename(&temp, &dest) {
std::fs::remove_file(temp);
return Err(anyhow!(err.to_string()));
}
Ok(dest)
}

async fn download_file(&'a self, source: Url, dest: PathBuf) -> Result<PathBuf, Error> {
if dest.exists() {
return Ok(dest);
}
dest.parent().map(create_dir_all);
let res = self.client.get(source).send().await;

match res {
Ok(r) => {
let rc = r.status().as_u16();
if rc >= 400 {
Err(anyhow!(format!("Response code in invalid range: {}", rc).to_string()))
} else {
self.save_response(r, dest).await
}
},
Err(err) => Err(anyhow!(err.to_string())),
}
}

pub async fn download_iterator<I>(&'a self, iter: I) -> Vec<PathBuf>
where
I: IntoIterator + 'a,
<I as IntoIterator>::Item: IntoDownload,
{
let to_dl: Vec<(Url, PathBuf)> = iter
let mut to_dl: Vec<(Url, String, PathBuf)> = iter
.into_iter()
.filter_map(|i| {
if let Ok(uri) = i.into_uri() {
Some((uri, i.into_fd(self.config)))
let c = uri.clone();
if let Some(host) = c.host_str() {
Some((uri, host.to_string(), i.into_fd(self.config)))
} else {
None
}
} else {
None
}
})
.collect();
self.prog.size(to_dl.len());

let v = futures::stream::iter(to_dl.into_iter().map(|from| async move {
let r = self.download_file(from.0.clone(), from.1.clone()).await;
self.prog.complete();
match r {
Ok(p) => Some(p),
Err(e) => {
log::error!("download of {:?} failed: {}", from.0.clone(), e);
None
let mut hosts: HashMap<String, usize> = HashMap::new();
let mut results : Vec<PathBuf> = vec![];
let mut started : usize = 0;
let mut handles: Vec<JoinHandle<(String, usize, Option<PathBuf>)>> = vec![];

while !to_dl.is_empty() || !handles.is_empty() {
let mut wait_list: Vec<(Url, String, PathBuf)> = vec![];
let mut next: Vec<JoinHandle<(String, usize, Option<PathBuf>)>> = vec![];

while let Some(handle) = handles.pop() {
if handle.is_finished() {
let r = handle.await.unwrap();
*hosts.entry(r.0).or_insert(1) -= 1;
started -= 1;
self.prog.progress(r.1);
self.prog.complete();
if let Some(path) = r.2 {
results.push(path);
}
} else {
next.push(handle);
}
}
}))
.buffer_unordered(32)
.collect::<Vec<Option<PathBuf>>>()
.await;
v.into_iter().filter_map(|x| x).collect::<Vec<PathBuf>>()

while ! to_dl.is_empty() && started < CONCURRENCY {
let from = to_dl.pop().unwrap();
let host = from.1.clone();
let entry = hosts.entry(host).or_insert(0);
if *entry >= HOST_LIMIT {
wait_list.push(from);
} else {
let source = from.0.clone();
let host = from.1.clone();
let dest = from.2.clone();
if dest.exists() {
results.push(dest);
} else {
let client = self.client.clone();
let handle: JoinHandle<(String, usize, Option<PathBuf>)> = tokio::spawn(async move {
dest.parent().map(create_dir_all);
let res = client.get(source.clone()).send().await;
let res: Result<(usize, PathBuf), Error> = match res {
Ok(r) => {
let rc = r.status().as_u16();
if rc >= 400 {
Err(anyhow!(format!("Response code in invalid range: {}", rc).to_string()))
} else {
save_response(r, dest).await
}
},
Err(err) => {
Err(anyhow!(err.to_string()))
},
};
match res {
Ok(r) => {
(host, r.0, Some(r.1))
},
Err(err) => {
log::error!("download of {:?} failed: {}", source, err);
(host, 0, None)
}
}
});
handles.push(handle);
started += 1;
*entry += 1;
}
}
}

for w in wait_list {
to_dl.push(w);
}

for w in next {
handles.push(w);
}
sleep(Duration::from_millis(100)).await;
}

results
}

pub(crate) async fn update_vidx<I>(&'a self, list: I) -> Result<Vec<PathBuf>, Error>
Expand Down

0 comments on commit 1c369fb

Please sign in to comment.