Skip to content

Commit

Permalink
Bug 2089: Allow S3 connection with IAM role instead of credentials
Browse files Browse the repository at this point in the history
https://winscp.net/tracker/2089

Source commit: 724382cc89254fb6b0d441f37bdae43c750fef09
  • Loading branch information
martinprikryl committed Aug 19, 2023
1 parent 62261a2 commit 6f231ac
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 16 deletions.
7 changes: 7 additions & 0 deletions source/core/Configuration.cpp
Expand Up @@ -121,6 +121,7 @@ void __fastcall TConfiguration::Default()
FParallelDurationThreshold = 10;
FMimeTypes = UnicodeString();
FCertificateStorage = EmptyStr;
FAWSMetadataService = EmptyStr;
FChecksumCommands = EmptyStr;
FDontReloadMoreThanSessions = 1000;
FScriptProgressFileNameLimit = 25;
Expand Down Expand Up @@ -267,6 +268,7 @@ UnicodeString __fastcall TConfiguration::PropertyToKey(const UnicodeString & Pro
KEY(Integer, KeyVersion); \
KEY(Bool, CollectUsage); \
KEY(String, CertificateStorage); \
KEY(String, AWSMetadataService); \
); \
BLOCK(L"Logging", CANCREATE, \
KEYEX(Bool, PermanentLogging, L"Logging"); \
Expand Down Expand Up @@ -1756,6 +1758,11 @@ UnicodeString TConfiguration::GetCertificateStorageExpanded()
return Result;
}
//---------------------------------------------------------------------
void TConfiguration::SetAWSMetadataService(const UnicodeString & value)
{
SET_CONFIG_PROPERTY(AWSMetadataService);
}
//---------------------------------------------------------------------
void __fastcall TConfiguration::SetTryFtpWhenSshFails(bool value)
{
SET_CONFIG_PROPERTY(TryFtpWhenSshFails);
Expand Down
3 changes: 3 additions & 0 deletions source/core/Configuration.h
Expand Up @@ -83,6 +83,7 @@ class TConfiguration : public TObject
int FQueueTransfersLimit;
int FParallelTransferThreshold;
UnicodeString FCertificateStorage;
UnicodeString FAWSMetadataService;
UnicodeString FChecksumCommands;

bool FDisablePasswordStoring;
Expand Down Expand Up @@ -150,6 +151,7 @@ class TConfiguration : public TObject
void __fastcall SetMimeTypes(UnicodeString value);
void SetCertificateStorage(const UnicodeString & value);
UnicodeString GetCertificateStorageExpanded();
void SetAWSMetadataService(const UnicodeString & value);
bool __fastcall GetCollectUsage();
void __fastcall SetCollectUsage(bool value);
bool __fastcall GetIsUnofficial();
Expand Down Expand Up @@ -335,6 +337,7 @@ class TConfiguration : public TObject
__property UnicodeString ExternalIpAddress = { read = FExternalIpAddress, write = SetExternalIpAddress };
__property UnicodeString CertificateStorage = { read = FCertificateStorage, write = SetCertificateStorage };
__property UnicodeString CertificateStorageExpanded = { read = GetCertificateStorageExpanded };
__property UnicodeString AWSMetadataService = { read = FAWSMetadataService, write = SetAWSMetadataService };
__property UnicodeString ChecksumCommands = { read = FChecksumCommands };
__property int LocalPortNumberMin = { read = FLocalPortNumberMin, write = SetLocalPortNumberMin };
__property int LocalPortNumberMax = { read = FLocalPortNumberMax, write = SetLocalPortNumberMax };
Expand Down
2 changes: 2 additions & 0 deletions source/core/Http.cpp
Expand Up @@ -10,6 +10,8 @@
#include "TextsCore.h"
#include <openssl/ssl.h>
//---------------------------------------------------------------------------
const int BasicHttpResponseLimit = 102400;
//---------------------------------------------------------------------------
THttp::THttp()
{
FProxyPort = 0;
Expand Down
2 changes: 2 additions & 0 deletions source/core/Http.h
Expand Up @@ -13,6 +13,8 @@ class THttp;
typedef void __fastcall (__closure * THttpDownloadEvent)(THttp * Sender, __int64 Size, bool & Cancel);
typedef void __fastcall (__closure * THttpErrorEvent)(THttp * Sender, int Status, const UnicodeString & Message);
//---------------------------------------------------------------------------
extern const int BasicHttpResponseLimit;
//---------------------------------------------------------------------------
class THttp
{
public:
Expand Down
118 changes: 113 additions & 5 deletions source/core/S3FileSystem.cpp
Expand Up @@ -19,6 +19,10 @@
#include <ne_request.h>
#include <StrUtils.hpp>
#include <limits>
#include "CoreMain.h"
#include "Http.h"
#include <System.JSON.hpp>
#include <System.DateUtils.hpp>
//---------------------------------------------------------------------------
#pragma package(smart_init)
//---------------------------------------------------------------------------
Expand Down Expand Up @@ -57,6 +61,11 @@ UnicodeString S3ConfigFileName;
TDateTime S3ConfigTimestamp;
std::unique_ptr<TCustomIniFile> S3ConfigFile;
UnicodeString S3Profile;
bool S3SecurityProfileChecked = false;
TDateTime S3CredentialsExpiration;
UnicodeString S3SecurityProfile;
typedef std::map<UnicodeString, UnicodeString> TS3Credentials;
TS3Credentials S3Credentials;
//---------------------------------------------------------------------------
static void NeedS3Config()
{
Expand Down Expand Up @@ -88,6 +97,7 @@ static void NeedS3Config()
{
S3ConfigTimestamp = Timestamp;
// TMemIniFile silently ignores empty paths or non-existing files
AppLog(L"Reading AWS credentials file");
S3ConfigFile.reset(new TMemIniFile(S3ConfigFileName));
}
}
Expand Down Expand Up @@ -124,11 +134,22 @@ TStrings * GetS3Profiles()
return Result.release();
}
//---------------------------------------------------------------------------
UnicodeString GetS3ConfigValue(const UnicodeString & Profile, const UnicodeString & Name, UnicodeString * Source)
UnicodeString ReadUrl(const UnicodeString & Url)
{
std::unique_ptr<THttp> Http(new THttp());
Http->URL = Url;
Http->ResponseLimit = BasicHttpResponseLimit;
Http->Get();
return Http->Response.Trim();
}
//---------------------------------------------------------------------------
UnicodeString GetS3ConfigValue(
const UnicodeString & Profile, const UnicodeString & Name, const UnicodeString & CredentialsName, UnicodeString * Source)
{
UnicodeString Result;
UnicodeString ASource;
TGuard Guard(LibS3Section.get());

try
{
if (Profile.IsEmpty())
Expand Down Expand Up @@ -161,6 +182,92 @@ UnicodeString GetS3ConfigValue(const UnicodeString & Profile, const UnicodeStrin
{
throw ExtException(&E, MainInstructions(LoadStr(S3_CONFIG_ERROR)));
}

if (Result.IsEmpty())
{
if (S3SecurityProfileChecked && (S3CredentialsExpiration != TDateTime()) && (IncHour(S3CredentialsExpiration, -1) < Now()))
{
AppLog(L"AWS security credentials has expired or is close to expiration, will retrieve new");
S3SecurityProfileChecked = false;
}

if (!S3SecurityProfileChecked)
{
S3Credentials.clear();
S3SecurityProfile = EmptyStr;
S3SecurityProfileChecked = true;
S3CredentialsExpiration = TDateTime();
try
{
UnicodeString AWSMetadataService = DefaultStr(Configuration->AWSMetadataService, L"http://169.254.169.254/latest/meta-data/");
UnicodeString SecurityCredentialsUrl = AWSMetadataService + L"iam/security-credentials/";

AppLogFmt(L"Retrieving AWS security credentials from %s", (SecurityCredentialsUrl));
S3SecurityProfile = ReadUrl(SecurityCredentialsUrl);

if (S3SecurityProfile.IsEmpty())
{
AppLog(L"No AWS security credentials role detected");
}
else
{
UnicodeString SecurityProfileUrl = SecurityCredentialsUrl + EncodeUrlString(S3SecurityProfile);
AppLogFmt(L"AWS security credentials role detected: %s, retrieving %s", (S3SecurityProfile, SecurityProfileUrl));
UnicodeString ProfileDataStr = ReadUrl(SecurityProfileUrl);

std::unique_ptr<TJSONValue> ProfileDataValue(TJSONObject::ParseJSONValue(ProfileDataStr));
TJSONObject * ProfileData = dynamic_cast<TJSONObject *>(ProfileDataValue.get());
if (ProfileData == NULL)
{
throw new Exception(FORMAT(L"Unexpected response: %s", (ProfileDataStr.SubString(1, 1000))));
}
TJSONValue * CodeValue = ProfileData->Values[L"Code"];
if (CodeValue == NULL)
{
throw new Exception(L"Missing \"Code\" value");
}
UnicodeString Code = CodeValue->Value();
if (!SameText(Code, L"Success"))
{
throw new Exception(FORMAT(L"Received non-success code: %s", (Code)));
}
TJSONValue * ExpirationValue = ProfileData->Values[L"Expiration"];
if (ExpirationValue == NULL)
{
throw new Exception(L"Missing \"Expiration\" value");
}
UnicodeString ExpirationStr = ExpirationValue->Value();
S3CredentialsExpiration = ISO8601ToDate(ExpirationStr, false);
AppLogFmt(L"Credentials expiration: %s", (StandardTimestamp(S3CredentialsExpiration)));

std::unique_ptr<TJSONPairEnumerator> Enumerator(ProfileData->GetEnumerator());
UnicodeString Names;
while (Enumerator->MoveNext())
{
TJSONPair * Pair = Enumerator->Current;
UnicodeString Name = Pair->JsonString->Value();
S3Credentials.insert(std::make_pair(Name, Pair->JsonValue->Value()));
AddToList(Names, Name, L", ");
}
AppLogFmt(L"Response contains following values: %s", (Names));
}
}
catch (Exception & E)
{
UnicodeString Message;
ExceptionMessage(&E, Message);
AppLogFmt(L"Error retrieving AWS security credentials role: %s", (Message));
}
}

TS3Credentials::const_iterator I = S3Credentials.find(CredentialsName);
if (I != S3Credentials.end())
{
Result = I->second;
ASource = FORMAT(L"meta-data/%s", (S3SecurityProfile));
}
}

if (Source != NULL)
{
*Source = ASource;
Expand All @@ -170,17 +277,17 @@ UnicodeString GetS3ConfigValue(const UnicodeString & Profile, const UnicodeStrin
//---------------------------------------------------------------------------
UnicodeString S3EnvUserName(const UnicodeString & Profile, UnicodeString * Source)
{
return GetS3ConfigValue(Profile, AWS_ACCESS_KEY_ID, Source);
return GetS3ConfigValue(Profile, AWS_ACCESS_KEY_ID, L"AccessKeyId", Source);
}
//---------------------------------------------------------------------------
UnicodeString S3EnvPassword(const UnicodeString & Profile, UnicodeString * Source)
{
return GetS3ConfigValue(Profile, AWS_SECRET_ACCESS_KEY, Source);
return GetS3ConfigValue(Profile, AWS_SECRET_ACCESS_KEY, L"SecretAccessKey", Source);
}
//---------------------------------------------------------------------------
UnicodeString S3EnvSessionToken(const UnicodeString & Profile, UnicodeString * Source)
{
return GetS3ConfigValue(Profile, AWS_SESSION_TOKEN, Source);
return GetS3ConfigValue(Profile, AWS_SESSION_TOKEN, L"Token", Source);
}
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
Expand Down Expand Up @@ -976,7 +1083,8 @@ S3Status TS3FileSystem::LibS3ListBucketCallback(
int Sec = 0;
// The libs3's parseIso8601Time uses mktime, so returns a local time, which we would have to complicatedly restore,
// Doing own parting instead as it's easier.
// Keep is sync with WebDAV
// Might be replaced with ISO8601ToDate.
// Keep is sync with WebDAV.
int Filled =
sscanf(Content->lastModifiedStr, ISO8601_FORMAT, &Year, &Month, &Day, &Hour, &Min, &Sec);
if (Filled == 6)
Expand Down
4 changes: 4 additions & 0 deletions source/core/SessionInfo.cpp
Expand Up @@ -1426,6 +1426,10 @@ void __fastcall TSessionLog::DoAddStartupInfo(TSessionData * Data)
{
ADF(L"S3: Session token: %s", (Data->S3SessionToken));
}
if (Data->S3CredentialsEnv)
{
ADF(L"S3: Credentials from AWS environment: %s", (DefaultStr(Data->S3Profile, L"General")));
}
}
if (FtpsOn)
{
Expand Down
30 changes: 20 additions & 10 deletions source/forms/Login.cpp
Expand Up @@ -78,6 +78,7 @@ __fastcall TLoginDialog::TLoginDialog(TComponent* AOwner)
FLinkedForm = NULL;
FRestoring = false;
FPrevPos = TPoint(std::numeric_limits<LONG>::min(), std::numeric_limits<LONG>::min());
FWasEverS3 = false;

// we need to make sure that window procedure is set asap
// (so that CM_SHOWINGCHANGED handling is applied)
Expand Down Expand Up @@ -674,6 +675,10 @@ void __fastcall TLoginDialog::UpdateControls()
bool FtpProtocol = (FSProtocol == fsFTP);
bool WebDavProtocol = (FSProtocol == fsWebDAV);
bool S3Protocol = (FSProtocol == fsS3);
if (S3Protocol)
{
FWasEverS3 = true;
}

// session
FtpsCombo->Visible = Editable && FtpProtocol;
Expand Down Expand Up @@ -2235,17 +2240,22 @@ void __fastcall TLoginDialog::TransferProtocolComboChange(TObject * Sender)
{
HostNameEdit->Clear();
}
if (UserNameEdit->Text == S3EnvUserName(S3Profile))
{
UserNameEdit->Clear();
}
if (PasswordEdit->Text == S3EnvPassword(S3Profile))
{
PasswordEdit->Clear();
}
if ((FSessionData != NULL) && (FSessionData->S3SessionToken == S3EnvSessionToken(S3Profile)))
// Optimization to avoid querying AWS metadata service.
// Smarter would be to tell S3EnvXXX functions not to do expensive queries.
if (FWasEverS3)
{
FSessionData->S3SessionToken = UnicodeString();
if (UserNameEdit->Text == S3EnvUserName(S3Profile))
{
UserNameEdit->Clear();
}
if (PasswordEdit->Text == S3EnvPassword(S3Profile))
{
PasswordEdit->Clear();
}
if ((FSessionData != NULL) && (FSessionData->S3SessionToken == S3EnvSessionToken(S3Profile)))
{
FSessionData->S3SessionToken = UnicodeString();
}
}
}
catch (...)
Expand Down
1 change: 1 addition & 0 deletions source/forms/Login.h
Expand Up @@ -326,6 +326,7 @@ class TLoginDialog : public TForm
UnicodeString FPasswordLabel;
int FFixedSessionImages;
bool FRestoring;
bool FWasEverS3;

void __fastcall LoadSession(TSessionData * SessionData);
void __fastcall LoadContents();
Expand Down
2 changes: 1 addition & 1 deletion source/windows/Setup.cpp
Expand Up @@ -930,7 +930,7 @@ static bool __fastcall DoQueryUpdates(TUpdatesConfiguration & Updates, bool Coll
AppLogFmt(L"Updates check URL: %s", (URL));
CheckForUpdatesHTTP->URL = URL;
// sanity check
CheckForUpdatesHTTP->ResponseLimit = 102400;
CheckForUpdatesHTTP->ResponseLimit = BasicHttpResponseLimit;
try
{
if (CollectUsage)
Expand Down

0 comments on commit 6f231ac

Please sign in to comment.