In [0]:
# Notebook input widgets, created in display order.
# NOTE(review): the PAT is entered through a plain-text widget and is likely visible to anyone
# viewing the notebook; consider reading it from a Databricks secret scope instead.
dbutils.widgets.dropdown("JAR Module", "NA", ["NA", "ingestion", "aggregation"])
for _text_widget in ("PAT", "Pipeline Name", "Activity Name", "Branch Name"):
    dbutils.widgets.text(_text_widget, "")
dbutils.widgets.dropdown(
    "Update Mode",
    "JAR and Check LS",
    ["JAR Only", "JAR and Check LS", "LS Only", "JAR and LS"],
)
dbutils.widgets.dropdown("Linked Service Runtime[To Check]", "13.3", ["13.3", "14.3"])
for _text_widget in ("New Linked Service", "Annotation"):
    dbutils.widgets.text(_text_widget, "")
del _text_widget  # keep the notebook namespace free of the loop variable

In [0]:
import requests
from requests.auth import HTTPBasicAuth
import json
import re
import hashlib

# AuthenticationHandler handles authentication and base URL setup
class AuthenticationHandler:
    """Hold Azure DevOps coordinates plus the PAT, and derive the REST base URLs.

    ``base_url`` addresses the Git repository API; ``feeds_base_url`` addresses the
    Azure Artifacts feed API for the same organization/project.
    """

    def __init__(self, pat, organization, project_name, repo_name, feed_name):
        self.pat, self.organization, self.project_name = pat, organization, project_name
        self.repo_name, self.feed_name = repo_name, feed_name

        project_root = f'https://dev.azure.com/{organization}/{project_name}'
        self.base_url = f'{project_root}/_apis/git/repositories/{repo_name}'
        self.feeds_base_url = (
            f'https://feeds.dev.azure.com/{organization}/{project_name}'
            f'/_apis/packaging/feeds/{feed_name}'
        )

    def get_auth(self):
        """Return HTTP Basic credentials: empty username, the PAT as password."""
        return HTTPBasicAuth('', self.pat)


# PackageManager handles fetching package details, versions, and jar files
class PackageManager:
    """Resolve the newest ``dp-<module>`` JAR file name published in the Azure Artifacts feed."""

    def __init__(self, auth_handler, module):
        self.auth_handler = auth_handler  # supplies feeds_base_url, feed_name and credentials
        self.module = module  # module name used to build the package name, e.g. "ingestion"

    @staticmethod
    def _version_sort_key(version):
        """Natural-sort key for a version string: digit runs compare as integers, the rest as text.

        ``re.split`` with a capturing group always alternates text/digit parts, so keys built
        from different versions hold the same type at each index and never raise TypeError.
        """
        parts = re.split(r'(\d+)', version)
        return tuple(int(part) if index % 2 else part for index, part in enumerate(parts))

    def get_package_id(self):
        """Return the feed id of ``com.organization.platform.<module>:dp-<module>``.

        Returns None (after printing the reason) when the request fails or the package
        is not listed in the feed.
        """
        url = f'{self.auth_handler.feeds_base_url}/packages'
        response = requests.get(url, auth=self.auth_handler.get_auth())
        if response.status_code != 200:
            print(f"Failed to get the package ID: {response.status_code}")
            print(response.text)
            return None

        expected_name = f'com.organization.platform.{self.module}:dp-{self.module}'
        for package_info in response.json()['value']:
            if package_info['normalizedName'] == expected_name:
                return package_info['id']
        print(f"Package '{expected_name}' not found in feed '{self.auth_handler.feed_name}'.")
        return None

    def get_latest_version(self, regex_pattern, include_all_versions, package_id):
        """Return ``{'version': ..., 'id': ...}`` for the highest matching version of *package_id*.

        Only versions whose ``normalizedVersion`` matches *regex_pattern* (``re.match``) are
        considered. They are ranked with a natural sort, so ``1.10.0`` outranks ``1.9.0``
        (a plain string comparison would pick ``1.9.0``). Returns None when nothing matches
        or the request fails.
        """
        include_flag = str(include_all_versions).lower()  # query string expects true/false
        url = (f'{self.auth_handler.feeds_base_url}/packages/{package_id}'
               f'?includeAllVersions={include_flag}&api-version=7.1-preview.1')
        response = requests.get(url, auth=self.auth_handler.get_auth())
        if response.status_code != 200:
            print(f"Failed to list package versions: {response.status_code}")
            print(response.text)  # error bodies are not guaranteed to be JSON
            return None

        matching_jars = [
            {'version': ver['normalizedVersion'], 'id': ver['id']}
            for ver in response.json()['versions']
            if re.match(regex_pattern, ver['normalizedVersion'])
        ]
        if not matching_jars:
            print(f"No package found matching the regex pattern '{regex_pattern}' in feed '{self.auth_handler.feed_name}'.")
            return None
        return max(matching_jars, key=lambda jar: self._version_sort_key(jar['version']))

    def get_jar_filename(self, regex_pattern):
        """Return the ``.jar`` file name of the newest version matching *regex_pattern*, or None."""
        package_id = self.get_package_id()
        if package_id is None:
            print("No matching JAR found.")
            return None

        version_info = self.get_latest_version(regex_pattern, True, package_id)
        if not version_info:
            print("No matching JAR found.")
            return None

        version = version_info['version']
        list_files_url = (f'{self.auth_handler.feeds_base_url}/packages/{package_id}'
                          f'/versions/{version_info["id"]}?api-version=7.1-preview.1')
        # Escape the version so its dots (and other metacharacters) match literally.
        file_name_regex = fr'.*\b({re.escape(version)}\.jar)\b$'
        response = requests.get(list_files_url, auth=self.auth_handler.get_auth())
        if response.status_code != 200:
            print(f"Failed to list files for version {version}: {response.status_code}")
            print("Response content:", response.text)
            return None

        try:
            files = response.json()['files']
        except (ValueError, KeyError) as e:
            print("Error decoding JSON response:", e)
            print("Response content:", response.text)
            return None

        for file_info in files:
            if re.match(file_name_regex, file_info['name']):
                return file_info['name']
        print(f"No JAR file matching version {version} found.")
        return None


# PipelineManager manages fetching, updating, and pushing pipeline JSON files
class PipelineManager:
    """Fetch, modify and push one ADF pipeline definition stored at ``/pipeline/<pipeline_name>.json``."""

    def __init__(self, auth_handler, pipeline_name):
        self.auth_handler = auth_handler
        self.pipeline_name = pipeline_name

    def _pipeline_path(self):
        """Repository path of this pipeline's JSON file."""
        return f'/pipeline/{self.pipeline_name}.json'

    def fetch_pipeline(self, branch):
        """GET the pipeline file from *branch* and return the raw ``requests.Response``."""
        url = f'{self.auth_handler.base_url}/items'
        # Let requests URL-encode the values: branch or pipeline names containing
        # '+', '&' or '#' would otherwise corrupt the query string.
        params = {
            'path': self._pipeline_path(),
            'versionDescriptor.version': branch,
            'api-version': '7.1-preview.1',
        }
        return requests.get(url, params=params, auth=self.auth_handler.get_auth())

    def _load_pipeline(self, utilities):
        """Return the pipeline JSON from the working branch, falling back to the base branch.

        Raises when the PAT is rejected; returns None when neither branch yields the file.
        """
        for branch in (utilities.new_branch_name, utilities.base_branch_name):
            response = self.fetch_pipeline(branch)
            if response.status_code == 200:
                return response.json()
            # Azure DevOps can answer a bad PAT with 203 (sign-in page) rather than 401.
            if response.status_code in (203, 401):
                raise Exception(f"Error Code: {response.status_code}, Please check Personal Access Token [Expired OR Incorrect]!!")
        print(f"Pipeline '{self._pipeline_path()}' could not be fetched from branch "
              f"'{utilities.new_branch_name}' or '{utilities.base_branch_name}'.")
        return None

    def _update_activity(self, activity, update_mode, ls_runtime, new_ls, regex_jar, replacement_string, utilities):
        """Apply the requested update to one activity dict in place.

        Modes: 0 = JAR only, 1 = JAR after verifying the current LS runtime,
        2 = linked service only, 3 = JAR and linked service.
        """
        if update_mode == 1:
            current_ls = activity['linkedServiceName']['referenceName']
            if not utilities.check_pattern(current_ls, ls_runtime):
                raise Exception(f"Current LS {current_ls} is not of runtime {ls_runtime}!!")
            activity['typeProperties'] = utilities.update_json(activity['typeProperties'], regex_jar, replacement_string, 'jar')
            print(f"Current LS {current_ls} is of runtime {ls_runtime}.")
            return

        if update_mode in (0, 3):
            activity['typeProperties'] = utilities.update_json(activity['typeProperties'], regex_jar, replacement_string, 'jar')

        if update_mode in (2, 3):
            linked_service = activity['linkedServiceName']
            # Exact comparison: a substring test would wrongly treat e.g. "LS_13_3"
            # as already present in "LS_13_3_Photon".
            if linked_service['referenceName'] == new_ls:
                print("Pipeline contains the input/new LS already.")
            elif utilities.check_new_ls(self.auth_handler, new_ls):
                # Set only the reference name; other keys (e.g. "type") stay untouched.
                linked_service['referenceName'] = new_ls
            else:
                raise Exception(f"Provided New LS {new_ls} does not Exist! Please Verify the name and try again.")

    def update_pipeline(self, activity_name, update_mode, ls_runtime, new_ls, annotation, replacement_string, utilities):
        """Return the updated pipeline as an indented JSON string, or None if it could not be fetched.

        Raises when the update mode is invalid, the activity is absent, the runtime/LS checks
        fail, or the update leaves the pipeline unchanged.
        """
        if update_mode not in (0, 1, 2, 3):
            raise Exception(f"Invalid Value received for {update_mode}!")

        pipeline_content = self._load_pipeline(utilities)
        if not pipeline_content:
            return None

        original_hash = utilities.calculate_hash(pipeline_content)
        regex_jar = fr'\b(org-{re.escape(utilities.module)}-)\b.*\b(\.jar)\b'

        activities = [activity for activity in pipeline_content['properties'].get('activities', [])
                      if activity['name'] == activity_name]
        if not activities:
            raise Exception(f"Activity '{activity_name}' not found in pipeline '{self.pipeline_name}'!")
        for activity in activities:
            self._update_activity(activity, update_mode, ls_runtime, new_ls, regex_jar, replacement_string, utilities)

        if utilities.calculate_hash(pipeline_content) == original_hash:
            raise Exception("No Changes detected! Either the pipeline already contains the latest JAR or the provided module/LS is unchanged.")

        if annotation:
            annotations = pipeline_content['properties'].setdefault('annotations', [])
            if annotation not in annotations:
                annotations.append(annotation)

        return json.dumps(pipeline_content, indent=4)

    def push_updated_file(self, updated_content, new_branch_name):
        """Commit *updated_content* as an edit of the pipeline file on *new_branch_name*.

        Returns True on success, False (after printing the reason) otherwise.
        """
        pipeline_path = self._pipeline_path()
        get_commits_url = f'{self.auth_handler.base_url}/commits?searchCriteria.itemVersion.version={new_branch_name}&$top=1&api-version=7.1-preview.1'
        response = requests.get(get_commits_url, auth=self.auth_handler.get_auth())
        if response.status_code != 200:
            print(f"Failed to get the latest commit from branch '{new_branch_name}': {response.status_code}")
            print("Response content:", response.text)
            return False

        commits = response.json().get('value', [])
        if not commits:
            print(f"No commits found on branch '{new_branch_name}'; cannot push the update.")
            return False

        payload = {
            # oldObjectId = current branch head, so the push fails if the branch moved meanwhile.
            "refUpdates": [{"name": f"refs/heads/{new_branch_name}", "oldObjectId": commits[0]['commitId']}],
            "commits": [{
                "comment": f"Update '{pipeline_path}'",
                "changes": [{"changeType": "edit", "item": {"path": pipeline_path}, "newContent": {"content": updated_content, "contentType": "rawtext"}}]
            }]
        }
        update_file_url = f'{self.auth_handler.base_url}/pushes?api-version=7.1-preview.2'
        headers = {'Content-Type': 'application/json'}
        response = requests.post(update_file_url, auth=self.auth_handler.get_auth(), headers=headers, json=payload)
        if response.status_code in (200, 201):
            print("File updated successfully.")
            return True
        print(f"Failed to update file: {response.status_code}")
        print("Response content:", response.text)
        return False


# BranchManager handles branch creation and pull request management
class BranchManager:
    """Cut the working branch from the base branch and open (or reuse) a pull request back into it."""

    def __init__(self, auth_handler, base_branch_name, new_branch_name):
        self.auth_handler = auth_handler
        self.base_branch_name = base_branch_name
        self.new_branch_name = new_branch_name

    def _pr_web_url(self, pr_id):
        """Browser URL of pull request *pr_id*; spaces (e.g. in the project name) become %20."""
        auth = self.auth_handler
        raw_url = f'https://dev.azure.com/{auth.organization}/{auth.project_name}/_git/{auth.repo_name}/pullrequest/{pr_id}'
        return raw_url.replace(" ", "%20")

    def create_branch(self):
        """Create ``refs/heads/<new_branch_name>`` at the head commit of the base branch.

        Returns True when the refs API answers 200/201, False otherwise (reason printed).
        """
        auth = self.auth_handler
        commits_response = requests.get(
            f'{auth.base_url}/commits?searchCriteria.itemVersion.version={self.base_branch_name}&$top=1&api-version=7.1-preview.1',
            auth=auth.get_auth(),
        )
        if commits_response.status_code != 200:
            print(f"Failed to get the latest commit from branch '{self.base_branch_name}': {commits_response.status_code}")
            print("Response content:", commits_response.text)
            return False

        head_commit = commits_response.json()['value'][0]['commitId']
        ref_update = {
            "name": f"refs/heads/{self.new_branch_name}",
            "oldObjectId": "0" * 40,  # all-zero old id: the ref must not exist yet (create)
            "newObjectId": head_commit,
        }
        create_response = requests.post(
            f'{auth.base_url}/refs?api-version=7.1-preview.1',
            auth=auth.get_auth(),
            headers={'Content-Type': 'application/json'},
            json=[ref_update],
        )
        if create_response.status_code not in (200, 201):
            print(f"Failed to create branch '{self.new_branch_name}': {create_response.status_code}")
            print("Response content:", create_response.text)
            return False
        print(f"Branch '{self.new_branch_name}' created successfully.")
        return True

    def create_pull_request(self, pr_title, pr_description, reviewers):
        """Return the web URL of a PR from the new branch into the base branch, or None on failure.

        An existing PR for the same source/target pair is reused instead of creating a duplicate.
        """
        auth = self.auth_handler
        json_headers = {'Content-Type': 'application/json'}
        source_ref = f"refs/heads/{self.new_branch_name}"
        target_ref = f"refs/heads/{self.base_branch_name}"

        search_url = (f'{auth.base_url}/pullrequests?searchCriteria.sourceRefName={source_ref}'
                      f'&searchCriteria.targetRefName={target_ref}&api-version=7.1-preview.1')
        existing = requests.get(search_url, auth=auth.get_auth(), headers=json_headers)
        if existing.status_code == 200:
            found = existing.json().get('value', [])
            if found:
                url = self._pr_web_url(found[0]['pullRequestId'])  # first match wins
                print(f"PR already exists: {url}")
                return url

        body = {
            "sourceRefName": source_ref,
            "targetRefName": target_ref,
            "title": pr_title,
            "description": pr_description,
            "reviewers": [{"id": reviewer_id} for reviewer_id in reviewers],
        }
        created = requests.post(f'{auth.base_url}/pullrequests?api-version=7.1-preview.1',
                                auth=auth.get_auth(), headers=json_headers, json=body)
        if created.status_code != 201:
            print(f"Failed to create Pull Request: {created.status_code}")
            print("Response content:", created.text)
            return None

        print("Pull Request created successfully.")
        url = self._pr_web_url(created.json()['pullRequestId'])
        print(f"PR URL: {url}")
        return url


# Utilities contains utility functions for JSON handling, hashing, etc.
class Utilities:
    """Helpers for the pipeline update flow: change hashing, literal/regex matching,
    recursive JSON string rewriting and linked-service existence checks."""

    def __init__(self, new_branch_name, base_branch_name, module):
        self.new_branch_name = new_branch_name  # working branch the change is pushed to
        self.base_branch_name = base_branch_name  # branch the working branch is created from
        self.module = module  # JAR module name used in the jar regex

    def calculate_hash(self, content):
        """Return the MD5 hex digest of *content* serialized as JSON with sorted keys.

        Key order therefore does not affect the digest. It is compared before/after an
        update to detect changes, not used for security.
        """
        content_str = json.dumps(content, sort_keys=True)
        return hashlib.md5(content_str.encode('utf-8')).hexdigest()

    def check_pattern(self, data, pattern):
        """Return a match object if *pattern* occurs literally (not as a regex) in *data*, else None."""
        return re.search(re.escape(pattern), data)

    def replace_pattern(self, text, pattern, replacement):
        """Replace every match of regex *pattern* in *text* with *replacement*, taken literally.

        A callable replacement keeps ``re.sub`` from interpreting backslashes or group
        references (``\\1``, ``\\g<name>``) in *replacement*. The replacement is a
        file or linked-service name, not a substitution template.
        """
        return re.sub(pattern, lambda _match: replacement, text)

    def update_json(self, data, pattern, replacement, key=""):
        """Return a copy of *data* with ``replace_pattern`` applied to every string inside it.

        Dicts and lists are rebuilt recursively; other values are returned unchanged.
        *key* is accepted for call-site readability but is not used.
        """
        if isinstance(data, dict):
            return {k: self.update_json(v, pattern, replacement, key) for k, v in data.items()}
        if isinstance(data, list):
            return [self.update_json(item, pattern, replacement, key) for item in data]
        if isinstance(data, str):
            return self.replace_pattern(data, pattern, replacement)
        return data

    def check_new_ls(self, auth_handler, new_ls):
        """Return True if ``/linkedService/<new_ls>.json`` exists in the repo, else False.

        Non-200/404 responses are printed and treated as "does not exist".
        """
        ls_path = f'/linkedService/{new_ls}.json'
        url = f'{auth_handler.base_url}/items?path={ls_path}&api-version=7.1-preview.1'
        response = requests.get(url, auth=auth_handler.get_auth())
        if response.status_code == 200:
            return True
        if response.status_code == 404:
            return False
        print(f"Failed to check the file: {response.status_code}")
        print("Response content:", response.text)
        return False


# The orchestrator class brings everything together
class UpdatePipelineJarADF:
    """Orchestrate the update: resolve the JAR, patch the pipeline, push to a new branch, open a PR."""

    def __init__(self, pat, pipeline_name, activity_name, module, new_branch_name, update_mode, ls_runtime, new_ls, annotation, organization='COSMOS-Sephora-Shared', project_name='Getting Started', repo_name='adf-resources', feed_name='sephora-dp', base_branch_name='master'):
        self.auth_handler = AuthenticationHandler(pat, organization, project_name, repo_name, feed_name)
        self.package_manager = PackageManager(self.auth_handler, module)
        self.pipeline_manager = PipelineManager(self.auth_handler, pipeline_name)
        self.branch_manager = BranchManager(self.auth_handler, base_branch_name, new_branch_name)
        self.utilities = Utilities(new_branch_name, base_branch_name, module)
        self.activity_name = activity_name
        self.update_mode = update_mode  # 0 JAR only, 1 JAR + LS runtime check, 2 LS only, 3 JAR + LS
        self.ls_runtime = ls_runtime
        self.new_ls = new_ls
        self.annotation = annotation

    def execute(self):
        """Run the full update flow and return the PR web URL, or None if any step stops it.

        Exceptions raised by the pipeline update (invalid mode, LS checks, no changes)
        propagate to the caller.
        """
        jar_filename = None
        if self.update_mode != 2:  # "LS Only" never rewrites the JAR reference
            regex_pattern = r'.*\b(master)\b$'  # only versions published from master
            jar_filename = self.package_manager.get_jar_filename(regex_pattern)
            if not jar_filename:
                print("Aborting: no JAR file found to update the pipeline with.")
                return None

        updated_content = self.pipeline_manager.update_pipeline(self.activity_name, self.update_mode, self.ls_runtime, self.new_ls, self.annotation, jar_filename, self.utilities)
        if not updated_content:
            print("Aborting: the pipeline could not be loaded.")
            return None

        if not self.branch_manager.create_branch():
            return None

        new_branch = self.utilities.new_branch_name
        if not self.pipeline_manager.push_updated_file(updated_content, new_branch):
            print("Aborting: pull request not created because the push failed.")
            return None

        title = f'Merge {new_branch} to {self.utilities.base_branch_name}'
        return self.branch_manager.create_pull_request(title, title, [])

In [0]:
def main():
    """Read and validate the notebook widgets, then run the pipeline JAR / linked-service update."""
    update_modes = {"JAR Only": 0, "JAR and Check LS": 1, "LS Only": 2, "JAR and LS": 3}
    ls_runtimes = {"13.3": "13_3", "14.3": "14_3"}  # widget value -> token expected in LS names

    # Strip text inputs: stray whitespace would silently break branch/pipeline/LS lookups.
    pat = dbutils.widgets.get("PAT").strip()
    module = dbutils.widgets.get("JAR Module")
    pipeline_name = dbutils.widgets.get("Pipeline Name").strip()
    new_branch_name = dbutils.widgets.get("Branch Name").strip()
    activity_name = dbutils.widgets.get("Activity Name").strip()
    update_mode = update_modes.get(dbutils.widgets.get("Update Mode"), -1)
    ls_runtime = ls_runtimes.get(dbutils.widgets.get("Linked Service Runtime[To Check]"))
    new_ls = dbutils.widgets.get("New Linked Service").strip()
    annotation = dbutils.widgets.get("Annotation").strip()

    if module not in ('aggregation', 'ingestion'):
        raise Exception('Invalid Module chosen! Please choose Aggregation OR Ingestion.')

    # Fail fast before any network call instead of failing deep inside the REST flow.
    required = {"PAT": pat, "Pipeline Name": pipeline_name, "Activity Name": activity_name, "Branch Name": new_branch_name}
    missing = [name for name, value in required.items() if not value]
    if missing:
        raise Exception(f"Missing required input(s): {', '.join(missing)}")
    if update_mode in (2, 3) and not new_ls:
        raise Exception("'New Linked Service' is required for the 'LS Only' and 'JAR and LS' update modes.")

    updater = UpdatePipelineJarADF(pat=pat, pipeline_name=pipeline_name, activity_name=activity_name, module=module, new_branch_name=new_branch_name, update_mode=update_mode, ls_runtime=ls_runtime, new_ls=new_ls, annotation=annotation)
    updater.execute()


if __name__ == "__main__":
    main()