# In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates


def define_dates(case_name):
    """Group metric records into a per-date dictionary.

    Each record must provide ``'date'``, ``'metric'`` and ``'metric_value'``
    keys and may provide an ``'age'`` key. The result maps each date to a
    dict of ``{metric: metric_value}``. When a record carries a non-None
    ``'age'``, it is stored under the ``'age'`` key of that date's dict.
    Later records overwrite earlier ones that share the same date and metric
    (or the same date, for ``'age'``).

    Args:
        case_name: Iterable of dict-like records.

    Returns:
        dict mapping date -> dict of metric values (plus optional ``'age'``).
    """
    grouped = {}
    for record in case_name:
        day, name, amount = record['date'], record['metric'], record['metric_value']
        record_age = record.get('age', None)

        row = grouped.setdefault(day, {})
        row[name] = amount
        # A missing or None age leaves any previously stored age untouched.
        if record_age is not None:
            row['age'] = record_age
    return grouped


def parse_date(datestring):
    """Parse a ``YYYY-MM-DD`` string (or list of them) with pandas.

    Args:
        datestring: Date text in ``%Y-%m-%d`` form, or anything
            ``pd.to_datetime`` accepts alongside that explicit format.

    Returns:
        A ``pd.Timestamp`` for scalar input; pandas' vectorised result
        for list-like input.
    """
    iso_day_format = "%Y-%m-%d"
    parsed = pd.to_datetime(datestring, format=iso_day_format)
    return parsed


def plot_for_year(year, title, timeseriesdf, marker='o'):
    """Plot one calendar year of ``timeseriesdf`` and show the figure.

    If ``timeseriesdf`` has an ``age`` column, ``metric_value`` is summed per
    (month, age) and drawn as grouped bars, with one bar per age value in
    each month. The aggregated table is also printed. Otherwise the
    ``positivity_rate`` column is drawn as a line against the index.

    Args:
        year: Calendar year to plot. Anything that formats as ``YYYY`` works,
            e.g. ``2023`` or ``"2023"``.
        title: Chart title; the year is appended to it.
        timeseriesdf: DataFrame whose index is compared with ``pd.Timestamp``
            bounds, so the index must be datetime-like.
        marker: Matplotlib marker for the line chart (unused for bars).
    """
    filtered_data = _rows_in_year(timeseriesdf, year)

    # Draw into one explicitly created Axes so that DataFrame.plot() does not
    # open a second figure and leave an empty one behind.
    if "age" in timeseriesdf.columns:
        fig, ax = plt.subplots(figsize=(15, 8))
        _plot_monthly_age_bars(ax, filtered_data)
    else:
        fig, ax = plt.subplots(figsize=(15, 6))
        _plot_line(ax, filtered_data, f'{title} ({year})', marker)

    ax.set_title(f'{title} {year}', fontsize=16)
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Only build a legend when something labelled was drawn; this avoids
    # matplotlib's "No artists with labels" warning for an empty year.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

    fig.tight_layout()
    plt.show()


def _rows_in_year(timeseriesdf, year):
    """Return the rows of ``timeseriesdf`` whose index lies in calendar ``year``."""
    start_date = pd.Timestamp(f'{year}-01-01')
    # Use a half-open [1 Jan, next 1 Jan) range so timestamps later on 31 Dec
    # are kept. DateOffset keeps this working whether year is an int or a str.
    next_start = start_date + pd.DateOffset(years=1)
    index = timeseriesdf.index
    return timeseriesdf[(index >= start_date) & (index < next_start)]


def _plot_monthly_age_bars(ax, filtered_data):
    """Draw ``metric_value`` summed per (month, age) as grouped bars on ``ax``."""
    if filtered_data.empty:
        # DataFrame.plot() raises on a frame with no numeric data.
        ax.text(0.5, 0.5, 'No data', transform=ax.transAxes, ha='center', va='center')
        return

    # assign() returns a new frame, so no column is written into a slice
    # (which would trigger SettingWithCopyWarning).
    monthly = filtered_data.assign(month=filtered_data.index.to_period('M'))

    # Sum every value that falls in the same month and age bucket.
    aggregated_data = monthly.groupby(['month', 'age'])['metric_value'].sum().reset_index()

    # Reshape to one row per month and one column per age value.
    pivoted_data = aggregated_data.pivot_table(
        index='month',
        columns='age',
        values='metric_value',
        aggfunc='sum',
    ).fillna(0)

    print(pivoted_data)

    pivoted_data.plot(kind='bar', width=0.8, ax=ax)

    ax.set_xticks(range(len(pivoted_data.index)))
    # Centre each label under its group of bars.
    ax.set_xticklabels(
        [month.strftime('%b %Y') for month in pivoted_data.index.to_timestamp()],
        rotation=0,
        ha='center',
    )


def _plot_line(ax, filtered_data, label, marker):
    """Draw ``positivity_rate`` against the index on ``ax`` with monthly ticks."""
    ax.plot(
        filtered_data.index,
        filtered_data['positivity_rate'],
        marker=marker,
        linestyle='-',
        label=label,
    )
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
    ax.xaxis.set_major_locator(mdates.MonthLocator())