<a href="https://colab.research.google.com/github/vivek6311/Artificial-Intelligence-with-Python/blob/master/Sankey_Diagram_for_India's_Trade.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import plotly.graph_objects as go
import numpy as np

# India's top 10 export partners in 2024 (user-supplied figures).
# Each row is (partner, export value, top export items); the two lookup
# dictionaries used by the plotting code are derived from this one table so
# that a partner's value and item list always stay side by side.
_EXPORT_ROWS = [
    ("United States", 79.44,
     ["Mineral fuels & oils", "Electrical machinery", "Pharmaceuticals",
      "Engineering goods", "Textiles"]),
    ("United Arab Emirates", 37.10,
     ["Mineral fuels & oils", "Pearls & precious stones",
      "Electrical machinery", "Machinery", "Pharmaceuticals"]),
    ("Netherlands", 24.22,
     ["Mineral fuels", "Pharmaceuticals", "Engineering goods",
      "Gems & jewelry", "Organic chemicals"]),
    ("Singapore", 15.62,
     ["Electronic goods", "Petroleum products", "Pharmaceuticals",
      "Machinery", "Textiles"]),
    ("China", 14.90,
     ["Minerals & metals", "Cotton yarn/fabrics", "Spices", "Fruits",
      "Pharmaceuticals"]),
    ("United Kingdom", 13.96,
     ["Pharmaceuticals", "Engineering goods", "Gems & jewelry", "Textiles",
      "Organic chemicals"]),
    ("Saudi Arabia", 12.10,
     ["Petroleum products", "Pharmaceuticals", "Organic chemicals",
      "Textiles", "Machinery"]),
    ("Bangladesh", 11.32,
     ["Textiles & garments", "Pharmaceuticals", "Engineering goods",
      "Leather products", "Ceramic products"]),
    # Approximate figure (averaged).
    ("Germany", 11.00,
     ["Machinery", "Pharmaceuticals", "Electrical machinery",
      "Pearls & gems", "Organic chemicals"]),
    # Approximate figure (averaged).
    ("Italy", 10.00,
     ["Gems & jewelry", "Pharmaceuticals", "Engineering goods",
      "Agricultural products", "Machinery"]),
]

# Partner -> export value.
india_exports = {partner: amount for partner, amount, _ in _EXPORT_ROWS}
# Partner -> list of top export items.
export_items = {partner: goods for partner, _, goods in _EXPORT_ROWS}

# India's top 10 import partners.
# Each row is (partner, import value, top import items); both lookup
# dictionaries below are derived from this single table.
_IMPORT_ROWS = [
    ("China", 101.75,
     ["Electrical machinery (telecom parts)", "Electronic goods",
      "Machinery", "Chemicals", "Plastics"]),
    ("Russia", 61.43,
     ["Mineral fuels/crude oil", "Coal", "Fertilizers", "Precious metals",
      "Iron & steel"]),
    ("United Arab Emirates", 48.02,
     ["Mineral fuels", "Precious stones", "Machinery",
      "Electrical machinery", "Chemicals"]),
    ("United States", 40.77,
     ["Mineral fuels", "Machinery", "Electrical goods", "Chemicals",
      "Organic chemicals"]),
    ("Saudi Arabia", 31.81,
     ["Petroleum crude & products", "Chemicals", "Fertilizers", "Plastics",
      "Iron & steel"]),
    ("Iraq", 30.00,
     ["Crude oil", "Petroleum products", "Fertilizers", "Chemicals",
      "Iron & steel"]),
    ("Indonesia", 23.41,
     ["Coal", "Crude oil", "Palm oil", "Rubber", "Electrical products"]),
    ("Switzerland", 21.24,
     ["Precious metals", "Pharmaceutical products", "Chemicals",
      "Machinery"]),
    ("Singapore", 21.20,
     ["Petroleum products", "Electrical machinery", "Chemicals",
      "Precision instruments"]),
    ("South Korea", 21.14,
     ["Electronics", "Machinery", "Chemicals", "Automobiles"]),
]

# Partner -> import value.
india_imports = {partner: amount for partner, amount, _ in _IMPORT_ROWS}
# Partner -> list of top import items.
import_items = {partner: goods for partner, _, goods in _IMPORT_ROWS}


def create_sankey_diagram(exports, imports, export_details, import_details):
    """
    Build and display a Sankey diagram of trade flows between India and its partners.

    Every export becomes a maroon link from the 'India' node to the partner
    node, and every import becomes a green link from the partner node to
    'India'. Each link carries a label (direction, value, top items) that is
    shown on hover. Per-partner value annotations are stacked down the right
    edge (exports) and the left edge (imports), ordered by descending value,
    and the totals are printed below the plot.

    Args:
        exports (dict): A dictionary of export partners and their trade values.
        imports (dict): A dictionary of import partners and their trade values.
        export_details (dict): A dictionary mapping export partners to their top items.
        import_details (dict): A dictionary mapping import partners to their top items.

    Returns:
        None: the figure is rendered with ``fig.show()`` and not returned.
    """
    # Totals shown in the footer annotation.
    total_exports = sum(exports.values())
    total_imports = sum(imports.values())

    # Node list: 'India' first (index 0), then every partner in alphabetical
    # order. A partner appearing in both dicts gets a single shared node.
    all_countries = ['India'] + sorted(set(exports) | set(imports))

    # Sankey links reference nodes by their integer position in the label list.
    country_to_index = {country: i for i, country in enumerate(all_countries)}
    india_index = country_to_index['India']

    # Parallel lists describing the links; entry k of each list belongs to link k.
    sources = []
    targets = []
    values = []
    link_labels = []
    link_colors = []

    # Export flows: India -> partner.
    for country, value in exports.items():
        sources.append(india_index)
        targets.append(country_to_index[country])
        values.append(value)
        top_items = ", ".join(export_details.get(country, []))
        link_labels.append(f'India to {country}<br>Value: ${value}B<br>Items: {top_items}')
        link_colors.append('#800000')  # Maroon for exports

    # Import flows: partner -> India.
    for country, value in imports.items():
        sources.append(country_to_index[country])
        targets.append(india_index)
        values.append(value)
        top_items = ", ".join(import_details.get(country, []))
        link_labels.append(f'{country} to India<br>Value: ${value}B<br>Items: {top_items}')
        link_colors.append('#008000')  # Green for imports

    # India's node is blue; every partner node is orange.
    node_colors = ['#1f77b4'] + ['#ff7f0e'] * (len(all_countries) - 1)

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=all_countries,
            # Per-node colours so India stands out from its partners.
            color=node_colors
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            # The hovertemplate below reads %{label}; without this the hover
            # text for every link would be empty.
            label=link_labels,
            color=link_colors,
            hovertemplate='%{label}<extra></extra>'
        )
    )])

    # Column headers above the plot and the totals footer below it.
    annotations = [
        dict(
            x=0.15, y=1.0,
            xref="paper", yref="paper",
            text="Imports",
            showarrow=False,
            font=dict(color="#008000", size=14, weight='bold')
        ),
        dict(
            x=0.85, y=1.025,
            xref="paper", yref="paper",
            text="Exports",
            showarrow=False,
            font=dict(color="#800000", size=14, weight='bold')
        ),
        dict(
            x=0.5, y=-0.05,  # Just below the plotting area
            xref="paper", yref="paper",
            text=f'Total Imports: ${total_imports:.2f}B<br>Total Exports: ${total_exports:.2f}B',
            showarrow=False,
            font=dict(color="#000000", size=14, weight='bold'),
            align='center'
        )
    ]

    # Largest partners are listed first (top of the column).
    sorted_exports = sorted(exports.items(), key=lambda item: item[1], reverse=True)
    sorted_imports = sorted(imports.items(), key=lambda item: item[1], reverse=True)

    # Evenly spaced vertical slots, top (0.95) to bottom (0.05), in paper coordinates.
    export_y_positions = np.linspace(0.95, 0.05, len(sorted_exports))
    import_y_positions = np.linspace(0.95, 0.05, len(sorted_imports))

    # Export partners: left-aligned text anchored on the right side.
    for (country, value), y_pos in zip(sorted_exports, export_y_positions):
        annotations.append(dict(
            x=0.9, y=y_pos,
            xref='paper', yref='paper',
            text=f'<b>{country}</b><br>${value}B',
            showarrow=False,
            font=dict(size=10, color='black'),
            align='left',
            xanchor='left',
            yanchor='middle'
        ))

    # Import partners: right-aligned text anchored on the left side.
    for (country, value), y_pos in zip(sorted_imports, import_y_positions):
        annotations.append(dict(
            x=0.1, y=y_pos,
            xref='paper', yref='paper',
            text=f'<b>{country}</b><br>${value}B',
            showarrow=False,
            font=dict(size=10, color='black'),
            align='right',
            xanchor='right',
            yanchor='middle'
        ))

    # Title, fonts, canvas size and the annotations assembled above.
    fig.update_layout(
        title_text="<b>India's Top Trade Partners in 2024: A Sankey Diagram of Imports & Exports</b>",
        title_font_color="red",
        title_font_size=24,
        font_size=12,
        font_family="Arial",
        paper_bgcolor="#ffffff",
        plot_bgcolor="#f8f9fa",
        margin=dict(l=20, r=20, t=50, b=20),
        annotations=annotations,
        width=1500,
        height=900,
        title_x=0.5  # Centre the title horizontally
    )

    # Display the figure
    fig.show()

# Render the diagram for the 2024 data defined above; arguments are passed by
# keyword so each dictionary is visibly matched to the parameter it feeds.
create_sankey_diagram(
    exports=india_exports,
    imports=india_imports,
    export_details=export_items,
    import_details=import_items,
)
