# Cross-border electricity trading visualization

In [None]:
import os
import multiprocessing
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt


In [None]:
# Directory holding the database export; the newest snapshot is read from it.
data_dir = '../database/'
df = pd.read_parquet(os.path.join(data_dir, 'latest.parquet'))

In [None]:
# Keep only rows up to "now". The index is compared against a UTC-aware
# timestamp, so it is presumably tz-aware (TODO confirm).
# ``datetime.today()`` returns naive *local* time; attaching tz='UTC' to it
# mislabels local wall-clock time as UTC and shifts the cut-off by the
# machine's UTC offset. Take the current instant directly in UTC instead.
df = df.loc[:pd.Timestamp.now(tz='UTC')]

In [None]:
# Select every column whose name contains '_flow' and strip that tag, so the
# remaining name identifies the counterpart (presumably a neighbouring country).
flow_columns = (
    df.filter(like='_flow')
      .rename(columns=lambda name: name.replace('_flow', ''))
)
flow_columns

In [None]:
# Summary statistics per column (count, mean, std, min, quartiles, max).
flow_columns.describe(percentiles=[0.25, 0.5, 0.75])

In [None]:
# Use a serif font for every figure created below. Other font settings can be
# added to the same dict if needed, e.g. 'font.size': 12, 'font.weight': 'bold'.
plt.rcParams.update({'font.family': 'serif'})

def plot_snapshot(df: pd.DataFrame, time_index: int, n: int = 24, show: bool = False, figname: "str | None" = None):
    """Render one frame of the cross-border flow chart.

    Each column of ``df`` gets its own categorical row on the y-axis. The last
    ``n`` values up to and including positional row ``time_index`` are
    scattered along the x-axis. The newest point gets the largest, most opaque
    marker, and older points shrink and fade.

    Parameters
    ----------
    df : pd.DataFrame
        One column per counterpart. The index must support ``.date()`` on its
        elements, because the last index value is used in the title.
    time_index : int
        Positional index of the newest row to draw.
    n : int
        Number of most recent points drawn per column.
    show : bool
        If True, call ``plt.show()`` before saving.
    figname : str or None
        Output image path. If the file already exists, nothing is rendered.
        If None, the figure is not saved.
    """
    # Skip frames rendered by an earlier run. ``os.path.isfile(None)`` raises
    # TypeError, so only check when a path was actually given.
    if figname is not None and os.path.isfile(figname):
        print(f"File exists {figname}")
        return

    df = df.iloc[:time_index + 1]  # history up to and including the current row

    # Per-age marker size and opacity: slot 0 is the oldest, slot n-1 the newest.
    marker_sizes = np.logspace(0, 2.5, num=n)
    alphas = np.logspace(-1, 0, num=n)
    cmap = plt.get_cmap("coolwarm")

    # Order rows by total absolute value over the visible history. Matplotlib
    # places categories in the order they are first plotted, so the smallest
    # total ends up lowest.
    # NOTE(review): the order is recomputed per frame, so rows can swap places
    # between consecutive frames of a movie.
    df_sorted = df[df.abs().sum().sort_values(ascending=True).index]

    fig, ax = plt.subplots(figsize=(10, 6))

    for country in df_sorted.columns:
        recent = df_sorted[country].tail(n)  # oldest first, newest last

        # Colour by |value| scaled to [0, 1] within this row's window. A
        # constant window would divide by zero and map to the colormap's
        # (transparent) "bad" colour, so use the midpoint instead.
        magnitude = recent.abs()
        low = magnitude.min()
        span = magnitude.max() - low
        if span > 0:
            normalized = ((magnitude - low) / span).to_numpy()
        else:
            normalized = np.full(len(magnitude), 0.5)

        # Align the window with the *end* of the size/alpha ramps, so the newest
        # point is always the largest even when fewer than n rows exist.
        offset = n - len(recent)
        name = str(country)
        label = name[:1].upper() + name[1:]

        for i, value in enumerate(recent.to_numpy()):
            age_slot = offset + i
            ax.scatter(value, label, s=marker_sizes[age_slot], color=cmap(normalized[i]),
                       edgecolor='k', alpha=alphas[age_slot])

    ax.axvline(x=0, color='k', linestyle='--')
    ax.set_xlabel('Import/Export [MW]', fontsize=15)
    ax.set_title(f'Electricity trade between Germany and its neighbours ({df_sorted.index[-1].date()})', fontsize=15)
    ax.set_xlim(-1500, 1500)  # fixed view so frames are comparable
    ax.tick_params(axis='both', which='both', labelsize=15, direction='in')
    ax.minorticks_on()

    # Side labels in figure coordinates, each in a semi-transparent white box.
    # Their placement (IMPORT left of the x=0 line, EXPORT right) implies that
    # negative values are imports.
    side_label_style = dict(fontsize=12, color='black',
                            bbox=dict(facecolor='white', alpha=0.5),
                            verticalalignment='bottom', transform=fig.transFigure)
    ax.text(0.2, 0.15, 'IMPORT', horizontalalignment='left', **side_label_style)
    ax.text(0.9, 0.15, 'EXPORT', horizontalalignment='right', **side_label_style)

    fig.tight_layout()
    if show:
        plt.show()
    if figname is not None:
        fig.savefig(figname)
    plt.close(fig)  # release the figure; many frames are rendered per process

def run_movie(df: pd.DataFrame, n=72, m=24):
    """Render one frame every ``m`` rows and encode the frames into an MP4.

    Frame ``i`` shows the data up to positional row ``n + i * m``, drawn by
    ``plot_snapshot`` with a trail of ``n`` points per column. Frames are
    written to ``./frames/`` and encoded into ``electricity_trade.mp4`` with
    ``ffmpeg``, which must be on the PATH.

    Parameters
    ----------
    df : pd.DataFrame
        Table passed unchanged to ``plot_snapshot``.
    n : int
        Trail length per column; also the positional row of the first frame.
    m : int
        Row step between consecutive frames (24 presumably means one frame
        per day for hourly rows — TODO confirm the data resolution).

    Raises
    ------
    subprocess.CalledProcessError
        If ffmpeg exits with a non-zero status.
    """
    import subprocess

    # Step exactly ``m`` rows, starting at row ``n``. The previous
    # ``np.linspace(n, len(df)-1, int((len(df)-n)/m), dtype=int)`` had three
    # problems:
    # - its floored step was neither ``m`` nor constant;
    # - it raised for fewer than ``n`` rows;
    # - it mapped frame numbers to different rows whenever ``df`` grew, so
    #   frames cached on disk (skipped by ``plot_snapshot``) silently went stale.
    # With a fixed anchor, frame i always shows row n + i*m.
    # NOTE(review): cached frames are still stale if n or m change between runs.
    selected_times = np.arange(n, len(df), m)
    if len(selected_times) == 0:
        print(f"Not enough rows ({len(df)}) to render a frame with n={n}")
        return

    output_dir = './frames/'
    os.makedirs(output_dir, exist_ok=True)

    # One argument tuple per frame. Sequential zero-padded names are what the
    # ffmpeg ``%08d`` input pattern below expects.
    args_list = [
        (df, time_index, n, False, f'{output_dir}frame_{i:08d}.png')
        for i, time_index in enumerate(selected_times)
    ]

    # Frames are independent, so render them in parallel.
    with multiprocessing.Pool(processes=4) as pool:
        pool.starmap(plot_snapshot, args_list)

    # The input rate is num_frames / 600 fps, which stretches the clip to
    # exactly 600 s (LinkedIn's maximum per the original note — TODO confirm
    # the current limit). The 24-fps cap only takes effect above 14,400
    # frames, and the clip is then longer than 600 s.
    num_frames = len(args_list)
    max_video_length_sec = 600
    framerate = min(24, num_frames / max_video_length_sec)

    # Fit into 1280x720 while keeping the frames' 10:6 aspect ratio, and pad the
    # remainder with white. A bare ``-s:v 1280x720`` stretched the image.
    video_filter = (
        'scale=1280:720:force_original_aspect_ratio=decrease,'
        'pad=1280:720:(ow-iw)/2:(oh-ih)/2:color=white'
    )
    # An argument list (no shell) and check=True make an ffmpeg failure visible
    # instead of being ignored as with ``os.system``.
    subprocess.run(
        ['ffmpeg', '-y', '-framerate', str(framerate),
         '-i', f'{output_dir}frame_%08d.png',
         '-vf', video_filter,
         '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
         'electricity_trade.mp4'],
        check=True,
    )
    
# Guard the entry point. Under the "spawn" start method (the default on
# Windows and macOS), pool workers re-import the main module, and an unguarded
# call would try to start the pool again. In a notebook, __name__ is
# '__main__', so this still runs.
if __name__ == '__main__':
    run_movie(df=flow_columns)