Skip to content

Commit

Permalink
#2550 #2567 support parallel downloading of multiple files
Browse files Browse the repository at this point in the history
CDP may send "in progress" events about multiple files simultaneously.
We need to track of all them, and pick the right file using given FileFilter.
  • Loading branch information
asolntsev committed Feb 3, 2024
1 parent e9e44cb commit a892fc8
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ void waitForNewFiles(Driver driver, FileFilter fileFilter, DownloadsFolder folde

protected void failFastIfNoChanges(Driver driver, DownloadsFolder folder, FileFilter filter,
long start, long timeout, long incrementTimeout) {
long lastFileUpdate = folder.lastModificationTime().orElse(-1L);
long lastFileUpdate = folder.lastModificationTime().orElse(start);
long now = currentTimeMillis();
long filesHasNotBeenUpdatedForMs = filesHasNotBeenUpdatedForMs(start, now, lastFileUpdate);
if (filesHasNotBeenUpdatedForMs > incrementTimeout) {
Expand Down
113 changes: 77 additions & 36 deletions src/main/java/com/codeborne/selenide/impl/DownloadFileToFolderCdp.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.codeborne.selenide.impl;

import com.codeborne.selenide.DownloadsFolder;
import com.codeborne.selenide.Driver;
import com.codeborne.selenide.Stopwatch;
import com.codeborne.selenide.ex.FileNotDownloadedError;
Expand All @@ -21,14 +22,15 @@
import javax.annotation.ParametersAreNonnullByDefault;
import java.io.File;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import static com.codeborne.selenide.impl.FileHelper.moveFile;
import static java.lang.System.currentTimeMillis;
import static java.util.Collections.emptyMap;
import static java.util.Objects.requireNonNull;
import static org.openqa.selenium.devtools.v120.browser.Browser.downloadProgress;
import static org.openqa.selenium.devtools.v120.browser.Browser.downloadWillBegin;

Expand Down Expand Up @@ -56,26 +58,23 @@ public File download(WebElementSource anyClickableElement,

Driver driver = anyClickableElement.driver();
DevTools devTools = initDevTools(driver);

AtomicBoolean downloadComplete = new AtomicBoolean(false);
AtomicReference<String> fileName = new AtomicReference<>();
AtomicLong lastModifiedAt = new AtomicLong(currentTimeMillis());
DownloadsFolder downloadsFolder = requireNonNull(driver.browserDownloadsFolder(), "Webdriver downloads folder is not configured");
CdpDownloads downloads = new CdpDownloads(downloadsFolder, new ConcurrentHashMap<>(1));

// Init download behaviour and listeners
prepareDownloadWithCdp(driver, devTools, fileName, downloadComplete, lastModifiedAt, timeout);
prepareDownloadWithCdp(driver, devTools, downloads, timeout);

// Perform action an element that begins download process
action.perform(anyClickableElement.driver(), clickable);

try {
// Wait until download
File file = waitUntilDownloadsCompleted(anyClickableElement.driver(), fileFilter,
timeout, incrementTimeout, lastModifiedAt, downloadComplete, fileName);
File file = waitUntilDownloadsCompleted(anyClickableElement.driver(), fileFilter, timeout, incrementTimeout, downloads);

//
if (!fileFilter.match(new DownloadedFile(file, emptyMap()))) {
String message = String.format("Failed to download file in %d ms.%s;%n actually downloaded: %s",
timeout, fileFilter.description(), file.getAbsolutePath());
String message = String.format("Failed to download file%s in %d ms.%s;%n actually downloaded: %s",
fileFilter.description(), timeout, fileFilter.description(), file.getAbsolutePath());
throw new FileNotDownloadedError(driver, message, timeout);
}

Expand All @@ -97,24 +96,24 @@ protected File archiveFile(Driver driver, File downloadedFile) {
}

private File waitUntilDownloadsCompleted(Driver driver, FileFilter fileFilter,
long timeout, long incrementTimeout,
AtomicLong lastModifiedAt, AtomicBoolean downloadComplete,
AtomicReference<String> fileName) {
long timeout, long incrementTimeout, CdpDownloads downloads) {
long pollingInterval = Math.max(driver.config().pollingInterval(), 100);
long downloadStartedAt = currentTimeMillis();
Stopwatch stopwatch = new Stopwatch(timeout);
do {
if (downloadComplete.get()) {
log.debug("File {} download is complete after {} ms.", fileName, stopwatch.getElapsedTimeMs());
return new File(driver.browserDownloadsFolder().toString(), fileName.get());
Optional<CdpDownload> downloadedFile = downloads.find(fileFilter);
if (downloadedFile.isPresent()) {
log.debug("File {} download is complete after {} ms.", downloadedFile.get().fileName, stopwatch.getElapsedTimeMs());
return downloadedFile.get().file();
}
else {
failFastIfNoChanges(driver, lastModifiedAt.get(), fileFilter, timeout, incrementTimeout);
failFastIfNoChanges(driver, downloads, fileFilter, downloadStartedAt, timeout, incrementTimeout);
}
stopwatch.sleep(pollingInterval);
}
while (!stopwatch.isTimeoutReached());

String message = "Failed to download file in %d ms".formatted(timeout);
String message = "Failed to download file%s in %d ms".formatted(fileFilter.description(), timeout);
throw new FileNotDownloadedError(driver, message, timeout);
}

Expand All @@ -133,31 +132,73 @@ private DevTools initDevTools(Driver driver) {
}

private void prepareDownloadWithCdp(Driver driver, DevTools devTools,
AtomicReference<String> fileName, AtomicBoolean downloadComplete, AtomicLong lastModifiedAt,
CdpDownloads downloads,
long timeout) {
devTools.send(Browser.setDownloadBehavior(
Browser.SetDownloadBehaviorBehavior.ALLOW,
Optional.empty(),
Optional.of(driver.browserDownloadsFolder().toString()),
Optional.of(downloads.folder.toString()),
Optional.of(true)));

devTools.clearListeners();
devTools.addListener(downloadWillBegin(), new DownloadWillBeginListener(id(), fileName, lastModifiedAt));
devTools.addListener(downloadProgress(), new DownloadProgressListener(id(), driver, downloadComplete, lastModifiedAt, timeout));
devTools.addListener(downloadWillBegin(), new DownloadWillBeginListener(id(), downloads));
devTools.addListener(downloadProgress(), new DownloadProgressListener(id(), driver, downloads, timeout));
}

private record CdpDownloads(
DownloadsFolder folder,
ConcurrentMap<String, CdpDownload> downloads
) {
private Optional<CdpDownload> find(FileFilter fileFilter) {
return downloads.values().stream()
.filter(download -> download.completed)
.filter(download -> fileFilter.match(download.file()))
.findAny();
}

private Optional<Long> lastModificationTime() {
return downloads.values().stream().map(download -> download.lastModifiedAt).max(Long::compare);
}

private void addFile(String guid, String fileName) {
downloads.put(guid, new CdpDownload(folder, fileName));
}

public void inProgress(String guid) {
downloads.get(guid).lastModifiedAt = currentTimeMillis();
}

public void finish(String guid) {
downloads.get(guid).completed = true;
}
}

private static class CdpDownload {
private final DownloadsFolder folder;
private final String fileName;
private long lastModifiedAt = currentTimeMillis();
private boolean completed;

private CdpDownload(DownloadsFolder folder, String fileName) {
this.folder = folder;
this.fileName = fileName;
}

private File file() {
return new File(folder.toString(), fileName);
}
}

private static long id() {
return SEQUENCE.incrementAndGet();
}

private record DownloadWillBeginListener(long id, AtomicReference<String> fileName, AtomicLong lastModifiedAt)
implements Consumer<DownloadWillBegin> {
private record DownloadWillBeginListener(long id, CdpDownloads downloads) implements Consumer<DownloadWillBegin> {
@Override
public void accept(DownloadWillBegin e) {
log.debug("[{}] Download will begin with suggested file name \"{}\" (url: \"{}\", frameId: {}, guid: {})",
id, e.getSuggestedFilename(), e.getUrl(), e.getFrameId(), e.getGuid());
fileName.set(e.getSuggestedFilename());
lastModifiedAt.set(currentTimeMillis());
downloads.addFile(e.getGuid(), e.getSuggestedFilename());
}

@Override
Expand All @@ -166,8 +207,7 @@ public String toString() {
}
}

private record DownloadProgressListener(long id, Driver driver,
AtomicBoolean downloadComplete, AtomicLong lastModifiedAt, long timeout)
private record DownloadProgressListener(long id, Driver driver, CdpDownloads downloads, long timeout)
implements Consumer<DownloadProgress> {
@Override
public void accept(DownloadProgress e) {
Expand All @@ -180,8 +220,8 @@ public void accept(DownloadProgress e) {
e.getState(), e.getReceivedBytes(), e.getTotalBytes(), e.getGuid());
throw new FileNotDownloadedError(driver, message, timeout);
}
case COMPLETED -> downloadComplete.set(true);
case INPROGRESS -> lastModifiedAt.set(currentTimeMillis());
case COMPLETED -> downloads.finish(e.getGuid());
case INPROGRESS -> downloads.inProgress(e.getGuid());
}
}

Expand All @@ -191,15 +231,16 @@ public String toString() {
}
}

protected void failFastIfNoChanges(Driver driver, long lastModifiedAt, FileFilter filter,
long timeout, long incrementTimeout) {
protected void failFastIfNoChanges(Driver driver, CdpDownloads downloads, FileFilter filter,
long downloadStartedAt, long timeout, long incrementTimeout) {
long now = currentTimeMillis();
long lastModifiedAt = downloads.lastModificationTime().orElse(downloadStartedAt);
long filesHasNotBeenUpdatedForMs = now - lastModifiedAt;
if (filesHasNotBeenUpdatedForMs > incrementTimeout) {
String message = String.format(
"Failed to download file%s in %d ms: file hasn't been modified for %s ms. " +
"(lastFileUpdate: %s, now: %s, incrementTimeout: %s)",
filter.description(), timeout, filesHasNotBeenUpdatedForMs,
"Failed to download file%s in %d ms: files in %s haven't been modified for %s ms. " +
"(lastUpdate: %s, now: %s, incrementTimeout: %s)",
filter.description(), timeout, downloads.folder, filesHasNotBeenUpdatedForMs,
lastModifiedAt, now, incrementTimeout);
throw new FileNotDownloadedError(driver, message, timeout);
}
Expand Down
13 changes: 13 additions & 0 deletions statics/src/test/java/integration/FileDownloadToFolderTest.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package integration;

import com.codeborne.selenide.Configuration;
import com.codeborne.selenide.impl.FileContent;
import org.apache.commons.lang3.SystemUtils;
import com.codeborne.selenide.ex.FileNotDownloadedError;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -278,4 +281,14 @@ public void cannotDownloadUsingProxy_ifBrowserIsOpenedWithoutProxy() {
.hasMessageStartingWith("Cannot download file: proxy server is not enabled. Setup proxyEnabled");
}

@ParameterizedTest
@ValueSource(strings = {"empty.html", "hello_world.txt", "download.html"})
void downloadMultipleFiles(String fileName) {
openFile("downloadMultipleFiles.html");

File text = $("#multiple-downloads").download(withName(fileName));

assertThat(text.getName()).isEqualTo(fileName);
assertThat(text.length()).isEqualTo(new FileContent(fileName).content().length());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import com.codeborne.selenide.Configuration;
import com.codeborne.selenide.ex.FileNotDownloadedError;
import com.codeborne.selenide.impl.FileContent;
import org.apache.commons.lang3.SystemUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -95,15 +98,15 @@ void downloadMissingFile() {
timeout = 888;
assertThatThrownBy(() -> $(byText("Download missing file")).download(withExtension("txt")))
.isInstanceOf(FileNotDownloadedError.class)
.hasMessageStartingWith("Failed to download file in 888 ms");
.hasMessageStartingWith("Failed to download file with extension \"txt\" in 888 ms");
}

@Test
void downloadMissingFileWithExtension() {
timeout = 888;
assertThatThrownBy(() -> $(byText("Download me")).download(withExtension("pdf")))
.isInstanceOf(FileNotDownloadedError.class)
.hasMessageStartingWith("Failed to download file in 888 ms. with extension \"pdf\"");
.hasMessageStartingWith("Failed to download file with extension \"pdf\" in 888 ms");
}

@Test
Expand Down Expand Up @@ -231,10 +234,9 @@ public void canSpecifyTimeoutForFileIncrement_downloadNotEvenStarted() {
.download(shortIncrementTimeout))
.isInstanceOf(FileNotDownloadedError.class)
.hasMessageStartingWith("""
Failed to download file with name "hello_world.txt" in 10000 ms
""".trim())
Failed to download file with name "hello_world.txt" in 10000 ms""")
.hasMessageMatching(Pattern.compile("""
(?s).+: file hasn't been modified for \\d+ ms\\. +\\(lastFileUpdate: -?\\d+, now: \\d+, incrementTimeout: 201\\).*
(?s).+: files in .+ haven't been modified for \\d+ ms\\. +\\(lastUpdate: -?\\d+, now: \\d+, incrementTimeout: 201\\).*
""".trim(), DOTALL));

closeWebDriver();
Expand Down Expand Up @@ -272,4 +274,14 @@ public void cannotDownloadUsingProxy_ifBrowserIsOpenedWithoutProxy() {
.hasMessageStartingWith("Cannot download file: proxy server is not enabled. Setup proxyEnabled");
}

@ParameterizedTest
@ValueSource(strings = {"empty.html", "hello_world.txt", "download.html"})
void downloadMultipleFiles(String fileName) {
openFile("downloadMultipleFiles.html");

File text = $("#multiple-downloads").download(withName(fileName));

assertThat(text.getName()).isEqualTo(fileName);
assertThat(text.length()).isEqualTo(new FileContent(fileName).content().length());
}
}

0 comments on commit a892fc8

Please sign in to comment.