In [None]:
import os
import ast
import time
import math
import warnings
import pandas as pd
import numpy as np
from functools import wraps

from geopandas.tools import geocode
from geopy.geocoders import Nominatim
from rapidfuzz import fuzz
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

In [None]:
# -------------------------------------------------------------------
# Wrapper Methods
# -------------------------------------------------------------------
def retry(retries: int = 3, delay: int = 2, fallback=(None, None, None)):
    """
    Decorator to retry the wrapped function upon exception.

    Every failed attempt is reported on stdout. Once all attempts have
    failed, the exception is swallowed and 'fallback' is returned instead.

    Parameters:
        retries (int): Number of attempts before giving up. With a value
                       of 0 or less the wrapped function is never called.
        delay (int): Delay (in seconds) between attempts.
        fallback: Value returned after all attempts failed. Defaults to
                  (None, None, None).

    Returns:
        A decorator that retries the function call.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, retries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    print(f"Attempt {attempt} failed in {func.__name__}: {e}")
                    # Only wait when another attempt follows; sleeping after
                    # the last failure would just delay the fallback.
                    if attempt < retries:
                        time.sleep(delay)
            # Return fallback value after all retries fail.
            return fallback
        return wrapper
    return decorator

def timeit(func):
    """
    Decorator to measure and print the execution time of a function.

    The timing is printed to stdout as a row of an ASCII table. The table
    header is printed once per decorated function, on its first completed
    call. FutureWarnings raised while the wrapped function runs are
    suppressed. If the wrapped function raises, nothing is printed and the
    exception propagates.

    Returns:
        The decorated function.
    """
    # Column widths are 30 and 20 characters including one space of padding
    # on each side, so cell contents are padded to 28 and 18 characters.
    border = f"+{'-'*30}+{'-'*20}+"

    @wraps(func)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            # perf_counter is monotonic, so the measured duration cannot be
            # distorted by system clock adjustments.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            execution_time = time.perf_counter() - start_time
            if not hasattr(wrapper, "has_printed_header"):
                print(border)
                print(f"| {'Function':<28} | {'Execution Time':<18} |")
                print(border)
                wrapper.has_printed_header = True
            duration_text = f"{execution_time:.4f} seconds"
            print(f"| {func.__name__:<28} | {duration_text:<18} |")
            print(border)
        return result
    return wrapper


# -------------------------------------------------------------------
# Geocoding of Port data
# -------------------------------------------------------------------
class PortGeoCoding:
    """
    A class for geocoding ports from a CSV database.

    This class loads port data, enriches it with primary and backup geo queries,
    and provides methods for retrieving a DataFrame containing Port, Country,
    Longitude, and Latitude. Geocoding is performed using Nominatim via geopy,
    with retry logic, timing reports, and fuzzy matching to verify that the
    returned country matches the expected country.
    """

    # Mean Earth radius in kilometers, used by the haversine formula.
    _EARTH_RADIUS_KM = 6371

    # Lower-cased country spellings that are rewritten before fuzzy matching.
    _COUNTRY_ALIASES = {
        "korea": "south korea",
        "usa": "united states",
        "hong kong": "china"
    }

    # Column layout of the geocoded result frame.
    _GEOCODED_COLUMNS = ["Port", "Country", "Longitude", "Latitude"]

    def __init__(self, csv_path: str):
        """
        Initializes the PortGeoCoding instance by reading the CSV file.

        Parameters:
            csv_path (str): The path to the CSV file containing port data.
        """
        self._ports = pd.read_csv(csv_path)
        self._enrich_ports()

    @staticmethod
    def _resolve_output_dir(dir_name: str) -> str:
        """
        Locate (or create) an output directory.

        The directory is looked up in the current directory first and in the
        parent directory second. If neither exists, it is created in the
        current directory.

        Parameters:
            dir_name (str): Name of the directory, e.g. "exports".

        Returns:
            str: Relative path of the directory to write to.
        """
        for candidate in (dir_name, os.path.join("..", dir_name)):
            if os.path.exists(candidate):
                return candidate
        os.makedirs(dir_name, exist_ok=True)
        return dir_name

    @staticmethod
    def _clean_port_name(name: str) -> str:
        """
        Title-case a port name and drop the standalone abbreviation "Anch".

        Only whole words are removed, so names that merely contain the
        letters (e.g. "Anchorage") stay intact. Runs of whitespace are
        collapsed to single spaces.

        Parameters:
            name (str): Raw port name.

        Returns:
            str: The cleaned port name.
        """
        words = name.title().split()
        return " ".join(word for word in words if word.rstrip(".") != "Anch")

    @staticmethod
    def _serializer(df: pd.DataFrame, file_name: str, file_format: str = None) -> None:
        """
        Export the provided DataFrame to a file in the designated export directory.

        This function checks for the existence of a directory named "datasets_processed"
        in the current directory or in the parent directory. If the directory exists, it is
        used; otherwise, the directory is created in the current directory. The DataFrame is
        then exported to a file with the specified file name and format in that directory.

        The export format can be Excel, CSV, or JSON. The format is determined by:
          - The explicit 'file_format' parameter if provided (case-insensitive), or
          - The file extension in 'file_name' (e.g., ".xlsx", ".csv", ".json").

        Parameters:
            df (pd.DataFrame): The DataFrame to be exported.
            file_name (str): The name of the output file (e.g., "output.xlsx", "output.csv", or "output.json").
            file_format (str, optional): The format to export to ("excel"/"xlsx", "csv", or "json").
                                         If not provided, the format is inferred from the file extension.

        Returns:
            None

        Raises:
            ValueError: If the format is none of the supported ones.
        """
        export_dir = PortGeoCoding._resolve_output_dir("datasets_processed")

        if file_format is None:
            file_format = file_name.split('.')[-1]
        file_format = file_format.lower()

        file_path = os.path.join(export_dir, file_name)

        if file_format in ['xlsx', 'excel']:
            df.to_excel(file_path, index=True)

        elif file_format == 'csv':
            df.to_csv(file_path, index=True)

        elif file_format == 'json':
            df.to_json(file_path, orient='records', lines=True)

        else:
            raise ValueError(f"Unsupported file format: {file_format}")

    def _is_country_match(self, expected: str, returned: str, threshold: int = 80) -> bool:
        """
        Checks whether the expected country and the returned country match.
        Both names are stripped, lower-cased and mapped through a small alias
        table. Then fuzzy matching (via RapidFuzz) computes a similarity
        score. If the score reaches the threshold, they are considered a match.

        Parameters:
            expected (str): The expected country (e.g., from the CSV).
            returned (str): The country returned by the geocoder.
            threshold (int): The minimum similarity percentage required.

        Returns:
            bool: True if the countries match (fuzzily); False otherwise.
        """
        aliases = PortGeoCoding._COUNTRY_ALIASES

        expected = expected.strip().lower()
        returned = returned.strip().lower()

        expected = aliases.get(expected, expected)
        returned = aliases.get(returned, returned)

        return fuzz.ratio(expected, returned) >= threshold

    def _enrich_ports(self):
        """
        Enriches the ports DataFrame by adding:
          - 'Geo_Query': A primary query built from 'Port Name' and 'Country'.
          - 'Backup_Queries': A list of backup queries parsed from 'Also known as'.
        """
        self._ports["Geo_Query"] = self._ports.apply(
            lambda row: f'{PortGeoCoding._clean_port_name(row["Port Name"])}, '
                        f'{row["Country"].title()}', axis=1
        )
        self._ports["Backup_Queries"] = self._ports.apply(self.get_backup_queries, axis=1)

    def get_backup_queries(self, row: pd.Series) -> list:
        """
        Generates backup geo queries from a row using the 'Also known as' field.

        The field is parsed as a Python literal. A list or tuple yields one
        query per string element, a quoted string yields a single query, and
        anything that does not parse (or parses to another type) is used as
        one plain alternative name. Empty names and the placeholder '-' are
        skipped.

        Parameters:
            row (pd.Series): A row from the ports DataFrame.

        Returns:
            list: A list of backup query strings.
        """
        backup_queries = []
        aka = row.get("Also known as", "")
        # Missing values (NaN/None) and other non-text cells carry no names.
        if not isinstance(aka, str) or not aka.strip():
            return backup_queries

        try:
            parsed = ast.literal_eval(aka)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a Python literal: treat the raw text as a single name.
            parsed = aka

        if isinstance(parsed, (list, tuple)):
            alt_names = list(parsed)
        elif isinstance(parsed, str):
            alt_names = [parsed]
        else:
            # Text such as "None" or "123" parses to a non-string literal;
            # fall back to the raw text instead of failing on .strip().
            alt_names = [aka]

        for alt in alt_names:
            if not isinstance(alt, str):
                continue
            alt_clean = alt.strip()
            if alt_clean and alt_clean != '-':
                backup_queries.append(f"{PortGeoCoding._clean_port_name(alt_clean)}"
                                      f", {row['Country'].title()}")
        return backup_queries

    @retry(retries=3, delay=1)
    def get_coordinates(self, in_location: str, silent: bool = True) -> tuple:
        """
        Retrieves the latitude, longitude, and country for a given location using Nominatim.

        Parameters:
            in_location (str): The location query string.
            silent (bool):     Suppresses prints, defaults to True.

        Returns:
            tuple: (latitude, longitude, country) if found, otherwise (None, None, None).
                   The country is None when the geocoder's address details carry
                   no 'country' entry.
        """
        # NOTE(review): requests are not throttled here; confirm this is
        # compatible with the usage policy of the Nominatim server in use.
        geolocator = Nominatim(user_agent="port_geo_coding")
        location = geolocator.geocode(in_location, addressdetails=True, language="en")

        country = None
        if location is not None:
            country = location.raw.get('address', {}).get('country', None)

        if not silent:
            print(f"+{'-'*30}+{'-'*30}+")
            print(f"| {'Input Location':<28} | {'Mapped Country':<28} |")
            print(f"+{'-'*30}+{'-'*30}+")
            print(f"| {in_location:<28} | {country if country else 'None':<28} |")
            print(f"+{'-'*30}+{'-'*30}+")

        if location is None:
            return (None, None, None)

        return (location.latitude, location.longitude, country)

    def _geocode_row(self, row: pd.Series, expected_country: str) -> tuple:
        """
        Geocode one port row, trying the primary query before the backups.

        A result is accepted only if latitude, longitude and country are all
        present and the country fuzzily matches 'expected_country'. The first
        accepted result wins; later queries are not sent.

        Parameters:
            row (pd.Series): A row holding 'Geo_Query' and 'Backup_Queries'.
            expected_country (str): The country the port is expected to be in.

        Returns:
            tuple: (latitude, longitude), or (None, None) if no query produced
                   a validated result.
        """
        for query in [row["Geo_Query"], *row["Backup_Queries"]]:
            lat, lon, country = self.get_coordinates(query)
            if lat is None or lon is None or country is None:
                continue
            if self._is_country_match(expected_country, country):
                return lat, lon
        # Coordinates of rejected results must not leak into the output.
        return None, None

    @timeit
    def yield_geocoded_ports(self, export: bool = True) -> pd.DataFrame:
        """
        Geocodes each port and returns a DataFrame with 'Port', 'Country', 'Longitude', and 'Latitude'.

        For each port, the method:
          1. Attempts geocoding using the primary 'Geo_Query'.
          2. Validates the result by ensuring coordinates are present and that the
             returned country (after normalization) fuzzily matches the expected country.
          3. If invalid, iterates over backup queries until a valid result is found.
          4. If no valid result is found, leaves the coordinates as None.

        Parameters:
            export (bool): Indicator whether the computed dataframe shall be exported
                           defaults to True (meaning gets exported to an excel file).

        Returns:
            pd.DataFrame: A DataFrame containing geocoded port data.
        """
        results = []
        for _, row in self._ports.iterrows():
            expected_country = row["Country"].title()
            lat, lon = self._geocode_row(row, expected_country)

            results.append({
                "Port": PortGeoCoding._clean_port_name(row["Port Name"]),
                "Country": expected_country,
                "Longitude": lon,
                "Latitude": lat
            })

        # Explicit columns keep the layout stable even when there are no rows.
        df_geocoded = pd.DataFrame(results, columns=PortGeoCoding._GEOCODED_COLUMNS)

        if export:
            self._serializer(df=df_geocoded, file_name="port_data_geocoded.xlsx", file_format="xlsx")

        return df_geocoded

    def _haversine(self, lon1: float, lat1: float, lon2: float, lat2: float) -> float:
        """
        Calculate the great-circle distance between two points on the Earth
        specified in decimal degrees using the Haversine formula.

        Parameters:
            lon1 (float): Longitude of the first point.
            lat1 (float): Latitude of the first point.
            lon2 (float): Longitude of the second point.
            lat2 (float): Latitude of the second point.

        Returns:
            float: Distance between the two points in kilometers.
        """
        lon1, lat1, lon2, lat2 = map(math.radians, [lon1, lat1, lon2, lat2])

        dlon = lon2 - lon1
        dlat = lat2 - lat1

        a = math.sin(dlat / 2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2)**2
        # Rounding can push 'a' marginally above 1 for near-antipodal points,
        # which would make sqrt/asin raise a domain error.
        a = min(1.0, max(0.0, a))
        c = 2 * math.asin(math.sqrt(a))

        return c * PortGeoCoding._EARTH_RADIUS_KM

    def compute_distance_matrix(self, df: pd.DataFrame, export: bool = True) -> pd.DataFrame:
        """
        Computes the distance matrix for all ports covered based on the
        haversine distance. Ports without longitude or latitude are dropped.

        Parameters:
            df (pd.DataFrame): The plain port dataframe containing 'Port',
                               'Longitude' and 'Latitude' columns.
            export (bool):     Indicator whether the computed dataframe shall be exported
                               defaults to True (meaning gets exported to an excel file).

        Returns:
            pd.DataFrame: The symmetric distance matrix in kilometers, indexed
                          and labelled by port name.
        """
        df = df.dropna(subset=["Longitude", "Latitude"]).reset_index(drop=True)

        # Pull the coordinates out once instead of doing a label lookup per pair.
        longitudes = df["Longitude"].to_numpy(dtype=float)
        latitudes = df["Latitude"].to_numpy(dtype=float)

        num_ports = df.shape[0]
        distance_matrix = np.zeros((num_ports, num_ports))

        # The matrix is symmetric with a zero diagonal, so only the upper
        # triangle has to be computed.
        for i in range(num_ports):
            for j in range(i + 1, num_ports):
                distance = self._haversine(longitudes[i], latitudes[i],
                                           longitudes[j], latitudes[j])
                distance_matrix[i, j] = distance
                distance_matrix[j, i] = distance

        distance_df = pd.DataFrame(distance_matrix,
                                   index=df['Port'],
                                   columns=df['Port'])

        if export:
            self._serializer(df=distance_df, file_name="distance_matrix.xlsx", file_format="xlsx")

        return distance_df

    @staticmethod
    def _load_world_map(world_map_path: str = None):
        """
        Load the country polygons used as the base map.

        Parameters:
            world_map_path (str, optional): Path or URL of a country dataset
                readable by GeoPandas. If omitted, the 'naturalearth_lowres'
                dataset bundled with older GeoPandas releases is used.

        Returns:
            GeoDataFrame: The world map.

        Raises:
            RuntimeError: If no path is given and the installed GeoPandas does
                          not provide the bundled dataset.
        """
        if world_map_path is not None:
            return gpd.read_file(world_map_path)

        try:
            bundled_path = gpd.datasets.get_path('naturalearth_lowres')
        except AttributeError as err:
            # GeoPandas 1.0 removed the bundled datasets.
            raise RuntimeError(
                "The installed GeoPandas does not provide 'naturalearth_lowres'. "
                "Download the Natural Earth 110m countries dataset and pass its "
                "location via 'world_map_path'."
            ) from err
        return gpd.read_file(bundled_path)

    def visualize_ports(self, df: pd.DataFrame, world_map_path: str = None) -> None:
        """
        Visualize port locations on a world map and export the resulting image.

        This function takes a DataFrame with port data that includes at least the following columns:
          - 'Longitude': The longitude of the port.
          - 'Latitude': The latitude of the port.
          - 'Country': The country in which the port is located.

        The function performs the following steps:
          1. Loads a base world map (Natural Earth's low-resolution dataset by default).
          2. Converts the port coordinates into a GeoDataFrame (rows without coordinates are left out).
          3. Determines which countries have at least one port data point and lightly shades those countries.
          4. Plots the world map, the shaded countries, and overlays the port locations with custom styling.
          5. Adds a title, axis labels, grid, legend, and other stylistic elements.
          6. Exports the final visualization to a PNG file in an "exports" directory. The directory is searched for in
             the current directory and its parent; if not found, an "exports" directory is created in the current directory.
          7. Displays the final plot.

        Parameters:
            df (pd.DataFrame): A DataFrame containing the port data with the columns 'Longitude', 'Latitude', and 'Country'.
            world_map_path (str, optional): Path or URL of a country dataset to use as base map. Needed on
                                            GeoPandas versions that no longer bundle 'naturalearth_lowres'.

        Returns:
            None

        Raises:
            RuntimeError: If no base map can be loaded.
        """
        world = PortGeoCoding._load_world_map(world_map_path)

        located = df.dropna(subset=["Longitude", "Latitude"])
        gdf_ports = gpd.GeoDataFrame(
            located, geometry=gpd.points_from_xy(located.Longitude, located.Latitude)
        )

        port_countries = df["Country"].dropna().unique()

        mapping = {
            'usa': 'United States of America',
            'korea': 'South Korea'
        }

        port_countries_adjusted = [mapping.get(country.lower(), country) for country in port_countries]
        wanted_names = {c.lower() for c in port_countries_adjusted}

        # The bundled dataset names the column 'name'; Natural Earth downloads
        # use upper-case column names.
        name_column = next((col for col in ("name", "NAME", "ADMIN") if col in world.columns), None)
        if name_column is None:
            shaded_countries = world.iloc[0:0]
        else:
            shaded_countries = world[world[name_column].str.lower().isin(wanted_names)]

        # Matplotlib 3.6 renamed the seaborn styles to 'seaborn-v0_8-*'.
        for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
            if style_name in plt.style.available:
                plt.style.use(style_name)
                break

        fig, ax = plt.subplots(figsize=(25, 16))

        world.plot(ax=ax, color='#f5f5f5', edgecolor='#bdbdbd')

        if not shaded_countries.empty:
            shaded_countries.plot(ax=ax, color='#d1e5f0', edgecolor='#bdbdbd', alpha=0.8)

        if not gdf_ports.empty:
            gdf_ports.plot(
                ax=ax,
                marker='o',
                color='red',
                markersize=80,
                alpha=0.85,
                edgecolor='black',
                linewidth=0.5,
                label='Port'
            )

        title_text = 'Port Locations'
        ax.set_title(title_text, fontsize=28, fontweight='bold', pad=20)
        ax.set_xlabel('Longitude', fontsize=18)
        ax.set_ylabel('Latitude', fontsize=18)
        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.grid(True, linestyle='--', alpha=0.5)

        # A proxy artist gives the legend a clean marker independent of the
        # styling used for the plotted points.
        port_marker = mlines.Line2D([], [], color='red', marker='o',
                                    linestyle='None', markersize=10, label='Port')

        ax.legend(handles=[port_marker], fontsize=16, loc='upper left')

        for spine in ax.spines.values():
            spine.set_visible(False)

        plt.tight_layout()

        export_dir = PortGeoCoding._resolve_output_dir("exports")
        export_path = os.path.join(export_dir, "port_locations.png")
        plt.savefig(export_path, dpi=300, bbox_inches='tight')

        plt.show()

In [None]:
# Relative path to the raw port dataset.
csv_file = "../data/port_data/shipping_ports_around_the_world/port_data.csv"
# Reading the CSV also builds the primary and backup geocoding queries.
port_geocoder = PortGeoCoding(csv_file)

# Geocode every port. With the default export=True the result is also written
# to "port_data_geocoded.xlsx" in the "datasets_processed" directory.
df_geo = port_geocoder.yield_geocoded_ports()
# Build the pairwise haversine distance matrix (exported to
# "distance_matrix.xlsx" by default). As the last expression of the cell,
# the returned matrix is also what the notebook displays.
port_geocoder.compute_distance_matrix(df_geo)