diff --git a/src/window/win.cc b/src/window/win.cc index 11146a025..27c8a91cc 100644 --- a/src/window/win.cc +++ b/src/window/win.cc @@ -4,11 +4,13 @@ #include #include #include -#include "window.hh" #include "WebView2.h" #include "WebView2EnvironmentOptions.h" +#include "window.hh" +#include "../core/types.hh" + #pragma comment(lib, "Shlwapi.lib") #pragma comment(lib, "urlmon.lib") @@ -749,38 +751,87 @@ namespace SSC { // UNREACHABLE - cannot continue } - const int MAX_ALLOWED_SCHEME_ORIGINS = 64; - int allowedSchemeOriginsCount = 4; - const WCHAR* allowedSchemeOrigins[MAX_ALLOWED_SCHEME_ORIGINS] = { - L"about://*", - L"https://*", - L"file://*", - L"socket://*" + static const int MAX_ALLOWED_SCHEME_ORIGINS = 64; + static const int MAX_CUSTOM_SCHEME_REGISTRATIONS = 64; + + struct SchemeRegistration { + String scheme; + }; + + ICoreWebView2CustomSchemeRegistration* registrations[MAX_CUSTOM_SCHEME_REGISTRATIONS] = {}; + Vector schemeRegistrations; + + schemeRegistrations.push_back({ "ipc" }); + schemeRegistrations.push_back({ "socket" }); + schemeRegistrations.push_back({ "node" }); + schemeRegistrations.push_back({ "npm" }); + + for (const auto& entry : split(opts.userConfig["webview_protocol-handlers"], " ")) { + const auto scheme = replace(trim(entry), ":", ""); + if (app.core->protocolHandlers.registerHandler(scheme)) { + schemeRegistrations.push_back({ scheme }); + } + } + + for (const auto& entry : opts.userConfig) { + const auto& key = entry.first; + if (key.starts_with("webview_protocol-handlers_")) { + const auto scheme = replace(replace(trim(key), "webview_protocol-handlers_", ""), ":", "");; + const auto data = entry.second; + if (app.core->protocolHandlers.registerHandler(scheme, { data })) { + schemeRegistrations.push_back({ scheme }); + } + } + } + + Set origins; + Set protocols = { + "about", + "https", + "socket", + "npm", + "node" }; static const auto devHost = SSC::getDevHost(); + const WCHAR* allowedOrigins[MAX_ALLOWED_SCHEME_ORIGINS] = {} + int allowedOriginsCount = 0; + int registrationsCount = 0; + if (devHost.starts_with("http:")) { - allowedSchemeOrigins[allowedSchemeOriginsCount++] = convertStringToWString(devHost).c_str(); + allowedOrigins[allowedOriginsCount] = convertStringToWString(devHost).c_str(); } - auto ipcSchemeRegistration = Microsoft::WRL::Make(L"ipc"); - ipcSchemeRegistration->put_HasAuthorityComponent(TRUE); - ipcSchemeRegistration->put_TreatAsSecure(TRUE); - ipcSchemeRegistration->SetAllowedOrigins(allowedSchemeOriginsCount, allowedSchemeOrigins); - - auto socketSchemeRegistration = Microsoft::WRL::Make(L"socket"); - socketSchemeRegistration->put_HasAuthorityComponent(TRUE); - socketSchemeRegistration->put_TreatAsSecure(TRUE); - socketSchemeRegistration->SetAllowedOrigins(allowedSchemeOriginsCount, allowedSchemeOrigins); - - // If someone can figure out how to allocate this so we can do it in a loop that'd be great, but even Ms is doing it like this: - // https://learn.microsoft.com/en-us/microsoft-edge/webview2/reference/win32/icorewebview2environmentoptions4?view=webview2-1.0.1587.40 - ICoreWebView2CustomSchemeRegistration* registrations[2] = { - ipcSchemeRegistration.Get(), - socketSchemeRegistration.Get() - }; + for (const auto& schemeRegistration : schemeRegistrations) { + protocols.insert(schemeRegistration.scheme); + } + + for (const auto& protocol : protocols) { + if (origins.size() == MAX_ALLOWED_SCHEME_ORIGINS) { + break; + } + + const auto origin = protocol + "://*"; + origins.insert(origin); + allowedOrigins[allowedOriginsCount++] = convertStringToWString(origin).c_str(); + } + + for (const auto& schemeRegistration : schemeRegistrations) { + auto registration = Microsoft::WRL::Make( + convertStringToWString(schemeRegistration.scheme).c_str() + ); + + registration->put_HasAuthorityComponent(TRUE); + registration->put_TreatAsSecure(TRUE); + registration->SetAllowedOrigins(origins.size(), allowedOrigins); - options4->SetCustomSchemeRegistrations(2, static_cast(registrations)); + registrations[registrationsCount++] = registrations.Get(); + } + + options4->SetCustomSchemeRegistrations( + registrationsCount, + static_cast(registrations) + ); auto init = [&, opts]() -> HRESULT { return CreateCoreWebView2EnvironmentWithOptions(