/
azure_postgresql_enforce_ssl_connection_enable.py
111 lines (91 loc) · 4 KB
/
azure_postgresql_enforce_ssl_connection_enable.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
import os
import sys
import logging
import time
from azure.mgmt.rdbms.postgresql import PostgreSQLManagementClient
from azure.identity import ClientSecretCredential
from azure.mgmt.rdbms.postgresql.models import ServerUpdateParameters
# Configure the root logger at INFO level so the progress messages this job
# emits via logging.info(...) are actually written out.
logging.basicConfig(level=logging.INFO)
class EnableSslEnforcement(object):
    """Remediation job that enables "Enforce SSL connection" on a PostgreSQL
    server managed through ``azure.mgmt.rdbms.postgresql``.

    The job receives a JSON payload from the Remediation Service, extracts the
    server name, resource group and subscription from it, and issues a server
    update with ``ssl_enforcement="Enabled"``.
    """

    # Seconds to sleep between status checks of the long-running update.
    POLL_INTERVAL_SECONDS = 5

    def parse(self, payload):
        """Parse payload received from Remediation Service.

        :param payload: JSON string containing parameters received from the
            remediation service.
        :type payload: str
        :returns: Dictionary with the keys ``resource_group_name``,
            ``postgre_server_name``, ``subscription_id`` and ``region``.
            ``resource_group_name`` is ``""`` if the object chain carries no
            string-typed ``ResourceGroup`` property.
        :rtype: dict
        :raises: KeyError if an expected field is missing;
            json.JSONDecodeError if the payload or its embedded ObjectChain
            is not valid JSON.
        """
        remediation_entry = json.loads(payload)
        finding_info = remediation_entry["notificationInfo"]["FindingInfo"]
        object_id = finding_info["ObjectId"]
        region = finding_info["Region"]
        # ObjectChain is itself a JSON document encoded as a string.
        object_chain_dict = json.loads(finding_info["ObjectChain"])
        subscription_id = object_chain_dict["cloudAccountId"]
        properties = object_chain_dict["properties"]

        resource_group_name = ""
        for prop in properties:  # `prop`, not `property`: avoid shadowing the builtin
            if prop["name"] == "ResourceGroup" and prop["type"] == "string":
                resource_group_name = prop["stringV"]
                break
        if not resource_group_name:
            # Keep returning "" for compatibility, but make the problem visible:
            # the Azure update call cannot succeed without a resource group.
            logging.warning("ResourceGroup property not found or empty in ObjectChain")

        logging.info("parsed params")
        logging.info(f" resource_group_name: {resource_group_name}")
        logging.info(f" postgre_server_name: {object_id}")
        logging.info(f" subscription_id: {subscription_id}")
        logging.info(f" region: {region}")
        return {
            "resource_group_name": resource_group_name,
            "postgre_server_name": object_id,
            "subscription_id": subscription_id,
            "region": region,
        }

    def remediate(self, client, resource_group_name, postgre_server_name):
        """Enable Enforce SSL connection for PostgreSQL Database Server.

        Starts ``client.servers.begin_update`` with
        ``ssl_enforcement="Enabled"`` and waits for the operation to finish,
        logging its status every ``POLL_INTERVAL_SECONDS`` seconds.

        :param client: Instance of the Azure PostgreSQLManagementClient.
        :param resource_group_name: The name of the resource group.
        :param postgre_server_name: The name of the PostgreSQL Server.
        :type resource_group_name: str
        :type postgre_server_name: str
        :returns: 0 on success.
        :rtype: int
        :raises: Any exception raised by the Azure SDK call or by the
            long-running operation. With the track-2 SDK used here this is
            typically azure.core.exceptions.HttpResponseError. The exception
            is logged with context and re-raised.
        """
        logging.info("Enabling Enforce SSL connection for PostgreSQL Database Server")
        logging.info(" executing client.servers.begin_update")
        logging.info(f" resource_group_name={resource_group_name}")
        logging.info(f" server_name={postgre_server_name}")
        try:
            poller = client.servers.begin_update(
                resource_group_name=resource_group_name,
                server_name=postgre_server_name,
                parameters=ServerUpdateParameters(ssl_enforcement="Enabled"),
            )
            # Poll explicitly (instead of only blocking in result()) so that
            # progress shows up in the job log.
            while not poller.done():
                time.sleep(self.POLL_INTERVAL_SECONDS)
                logging.info(f"The remediation job status: {poller.status()}")
            # result() re-raises any failure of the long-running operation.
            poller.result()
        except Exception as e:
            logging.error(
                f"Failed to enable SSL enforcement on server "
                f"'{postgre_server_name}' in resource group "
                f"'{resource_group_name}': {e}"
            )
            raise
        return 0

    def run(self, args):
        """Run the remediation job.

        :param args: Command-line arguments (as in ``sys.argv``); ``args[1]``
            must be the JSON payload from the Remediation Service.
        :type args: list
        :returns: 0 on success, 1 if no payload argument was supplied.
        :rtype: int
        :raises: Whatever parse() or remediate() raise.
        """
        if len(args) < 2:
            logging.error("Missing payload argument: expected remediation JSON as args[1]")
            return 1
        params = self.parse(args[1])
        credential = ClientSecretCredential(
            client_id=os.environ.get("AZURE_CLIENT_ID"),
            client_secret=os.environ.get("AZURE_CLIENT_SECRET"),
            tenant_id=os.environ.get("AZURE_TENANT_ID"),
        )
        # Use the management client as a context manager so its HTTP session
        # is closed when the job finishes, even if remediation raises.
        with PostgreSQLManagementClient(credential, params["subscription_id"]) as client:
            return self.remediate(
                client, params["resource_group_name"], params["postgre_server_name"],
            )
if __name__ == "__main__":
    # Run the job on the raw command line and hand its integer result back
    # to the shell as the process exit status.
    job = EnableSslEnforcement()
    sys.exit(job.run(sys.argv))