## Imports and globals

In [None]:
from typing import Any, Dict, List, Tuple, Union, Optional
from msticpy.data.drivers.driver_base import DriverBase, QuerySource
from msticpy.common import pkg_config as config
from msticpy.common.exceptions import MsticpyException
from msticpy._version import VERSION
from msal import PublicClientApplication

import pandas as pd
import base64
import re
import json
import requests
import datetime

# Module version, taken from the VERSION constant imported from msticpy._version.
__version__ = VERSION
# Author credit string for this module.
__author__ = "Martijn Veken, Ruben Bouman"

## MTP data driver

In [None]:
class MTPDriver(DriverBase):
    """
    Driver class to retrieve data from the MTP advanced hunting API.

    Authentication for the MTP API is based on "delegated permissions". The
    advantage of this is that no API keys need to be saved anywhere; instead
    the user logs on interactively (including MFA when configured) using the
    device code flow.

    API limitations:
    The MTP API limits query results to 100.000 rows per query. It also has a
    limit of 10 calls per minute, 10 minutes of running time every hour and
    4 hours of running time a day. The maximum execution time of a single
    request is 10 minutes. Read more on:

    https://docs.microsoft.com/en-gb/microsoft-365/security/mtp/api-advanced-hunting
    """

    # Base URL of the API; also used to build the OAuth scope.
    _ENDPOINT = 'https://api.security.microsoft.com'
    # URL that advanced hunting queries are posted to.
    _QUERY_URL = _ENDPOINT + '/api/advancedhunting/run'
    # Prefix of the AAD authority URL; the tenant id is appended to it.
    _AUTHORITY = 'https://login.microsoftonline.com/'

    def __init__(self, **kwargs):
        """
        Instantiate MTPDriver.

        Parameters
        ----------
        debug : bool, optional
            If True, each query is printed before it is sent
            (default is False).
        """
        super().__init__()
        self._access_token = None
        self._debug = kwargs.get("debug", False)

    def connect(self, connection_str: Optional[str] = None, **kwargs):
        """
        Connect to data source.

        A cached token is re-used when one is available; otherwise the
        device code flow is started and the sign-in instructions are printed.

        Parameters
        ----------
        connection_str : str, optional
            Connect to a data source, not used in this driver

        Other Parameters
        ----------------
        app_id : str
            Client (application) id used to create the MSAL application.
        tenant_id : str
            Tenant id appended to the authority URL.

        Returns
        -------
        dict
            The MSAL logon result with the "access_token" entry blanked out.

        Raises
        ------
        MsticpyException
            If `app_id` or `tenant_id` is missing, if the device flow cannot
            be created, or if no access token was obtained.
        """
        app_id = kwargs.get("app_id")
        tenant_id = kwargs.get("tenant_id")
        if not (app_id and tenant_id):
            raise MsticpyException("Missing parameters.")

        scopes = [self._ENDPOINT + '/.default']
        app = PublicClientApplication(
            app_id,
            authority=self._AUTHORITY + tenant_id)

        # First check if there's a token in the cache:
        logon_result = None
        accounts = app.get_accounts()
        if accounts:
            logon_result = app.acquire_token_silent(scopes, account=accounts[0])
            # Only announce re-use when the silent call really produced a result.
            if logon_result:
                print('Re-using token from cache.\n')

        if not logon_result:
            # Nothing usable in the cache, so get a new token:
            flow = app.initiate_device_flow(scopes=scopes)
            if 'user_code' not in flow:
                raise MsticpyException('Fail to create device flow. Err: %s' % json.dumps(flow, indent=4))

            print(flow['message'])
            print('Waiting for authentication...\n')
            logon_result = app.acquire_token_by_device_flow(flow)

        if 'access_token' not in logon_result:
            raise MsticpyException("%s, %s, %s" % (logon_result.get("error"), logon_result.get("error_description"), logon_result.get("correlation_id")))

        self._print_logon_info(logon_result['access_token'])
        self._access_token = logon_result["access_token"]
        self._connected = True
        # Do not hand the raw token back to the caller.
        logon_result["access_token"] = None
        return logon_result

    def query_with_results(self, query: str, **kwargs) -> Tuple[Optional[pd.DataFrame], Any]:
        """
        Execute query string and return DataFrame of results.

        Parameters
        ----------
        query : str
            The kql query to execute

        Returns
        -------
        Tuple[Optional[pd.DataFrame], Any]
            A DataFrame and the raw "Results" list if the query returned
            rows; otherwise None and the complete JSON response.

        Raises
        ------
        ConnectionError
            If the driver is still not connected after a connect attempt.
        ConnectionRefusedError
            If the API answers with HTTP 401 or HTTP 429.
        requests.HTTPError
            For any other HTTP error status.
        """
        if not self.connected or self._access_token is None:
            self.connect(self.current_connection)
            if not self.connected:
                # Single message string: ConnectionError is an OSError, which
                # would interpret two arguments as (errno, strerror).
                raise ConnectionError(
                    "Source is not connected. Please call connect() and retry."
                )

        if self._debug:
            print(query)

        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Authorization': "Bearer " + self._access_token
        }
        data = json.dumps({'Query': query}).encode("utf-8")

        response = requests.post(url=self._QUERY_URL, headers=headers, data=data)

        if response.status_code != requests.codes["ok"]:
            if response.status_code == 401:
                raise ConnectionRefusedError(
                    "Authentication failed - possible timeout. Please re-connect."
                )
            # Raise an exception to handle hitting API limits
            if response.status_code == 429:
                raise ConnectionRefusedError("You have likely hit the API limit. ")
            response.raise_for_status()

        json_response = response.json()
        result = json_response["Results"] if "Results" in json_response else None

        if not result:
            print("Warning - query did not return any results.")
            return None, json_response

        # pd.json_normalize is the supported location; fall back to the old
        # pd.io.json location for pandas versions that predate it.
        normalize = getattr(pd, "json_normalize", None) or pd.io.json.json_normalize
        return normalize(result), result

    def query(self, query: str, query_source: QuerySource = None, **kwargs) -> Union[pd.DataFrame, Any]:
        """
        Execute query string and return DataFrame of results.

        Parameters
        ----------
        query : str
            The query to execute
        query_source : QuerySource
            The query definition object (not used by this driver)

        Returns
        -------
        Union[pd.DataFrame, Any]
            A DataFrame if the query returned rows, otherwise None.
        """
        del query_source, kwargs
        return self.query_with_results(query)[0]

    def _print_logon_info(self, access_token):
        """Print name, UPN and expiry time taken from the token's claims."""
        claims = self._token_claims(access_token)
        print('You are succesfully logged in: ')
        # Claims are optional in a token, so never fail the logon over them.
        print('Name: %s %s' % (claims.get('given_name', ''), claims.get('family_name', '')))
        print('UPN:  %s' % claims.get('upn', ''))
        expiry = claims.get('exp')
        if expiry is not None:
            print('Token expiration: %s ' % datetime.datetime.fromtimestamp(expiry).isoformat())

    def _token_claims(self, access_token):
        """Return the claims dict of a JWT, or an empty dict if undecodable."""
        try:
            payload = access_token.split('.')[1]
            claims = json.loads(self._decode_base64(payload).decode('utf-8'))
        except (IndexError, ValueError):
            # binascii.Error, JSONDecodeError and UnicodeDecodeError are all
            # ValueError subclasses; the token is still usable as a bearer.
            return {}
        return claims if isinstance(claims, dict) else {}

    def _decode_base64(self, data, altchars=b'+/'):
        """Decode base64, padding being optional.

        With the default `altchars` the URL-safe alphabet ("-" and "_"), which
        is what JWT segments use, is accepted as well.

        :param data: Base64 data as an ASCII string (or bytes)
        :param altchars: The two characters used for values 62 and 63
        :returns: The decoded bytes.

        """
        if isinstance(data, bytes):
            data = data.decode('ascii')
        alphabet = altchars.decode('ascii') if isinstance(altchars, bytes) else altchars
        if alphabet == '+/':
            # Map base64url characters onto the standard alphabet instead of
            # dropping them in the normalisation step below.
            data = data.translate(str.maketrans('-_', '+/'))
        # Drop everything that is not part of the alphabet (incl. padding).
        data = re.sub('[^a-zA-Z0-9%s]+' % re.escape(alphabet), '', data)
        missing_padding = len(data) % 4
        if missing_padding:
            data += '=' * (4 - missing_padding)
        return base64.b64decode(data, altchars)

