In [60]:
class SeasonalCycle:
    '''
    class SeasonalCycle

    Derive and plot the seasonal cycle of Chl-a or NPP of the four PFTs
    (small phytoplankton, diatoms, coccolithophores, Phaeocystis), averaged
    over three Arctic and three Southern Ocean (SO) latitude bands.

    Parameters
    ----------
    resultpath : str or Path
        Directory holding the yearly FESOM output files named
        ``<variable>.fesom.<year>.nc``.
    savepath : str or Path
        Directory the figure is written to when ``savefig`` is True.
    mesh : pyfesom2 mesh
        Mesh object; its ``y2`` node latitudes define the regions.
    first_year, last_year : int
        Inclusive range of model years averaged together.
    type : {'Chl', 'NPP'}
        Quantity to analyse.
    mapproj : str
        Projection key passed to ``pf.get_proj``; the returned projection
        replaces ``self.mapproj``.
    cmap, verbose, output : optional
        Stored for interface compatibility; not used by this class.
    savefig : bool
        If True, save the figure into ``savepath``.
    plotting : bool
        If False, only compute ``self.seasonal_mean`` and skip the figure.
    runname : str
        Prefix of the saved figure's file name.

    Attributes
    ----------
    seasonal_mean : dict
        ``{pft: {region: ndarray of 12 values}}`` multi-year monthly means.
        ``pft`` is one of 'Phy', 'Dia', 'Cocco', 'Phaeo'. The optional PFTs
        'Cocco'/'Phaeo' are present only if at least one of their yearly
        files exists. ``region`` is one of 'Arc1'..'Arc3', 'SO1'..'SO3'.
        SO series are rotated to start in July.

    Raises
    ------
    ValueError
        If ``type`` is not 'Chl'/'NPP' or ``last_year < first_year``.
    '''

    # Output file stem (identical to the netCDF variable name) per quantity and PFT.
    _VARIABLES = {
        'Chl': {'Phy': 'PhyChl', 'Dia': 'DiaChl', 'Cocco': 'CoccoChl', 'Phaeo': 'PhaeoChl'},
        'NPP': {'Phy': 'NPPn', 'Dia': 'NPPd', 'Cocco': 'NPPc', 'Phaeo': 'NPPp'},
    }
    # PFTs whose files may be missing (runs without coccolithophores / Phaeocystis).
    _OPTIONAL_PFTS = ('Cocco', 'Phaeo')
    # Legend label per PFT, in plotting order.
    _LABELS = {'Phy': 'SmallPhy', 'Dia': 'Diatoms', 'Cocco': 'Cocco', 'Phaeo': 'Phaeo'}
    _YLABELS = {'Chl': 'Chl.a [mg m$^{-3}$]', 'NPP': 'NPP [mg C m$^{-2}$ d$^{-1}$]'}
    # (Arctic, Southern Ocean) y-axis limits per quantity.
    _YLIMS = {'Chl': ((0, 0.5), (0, 0.85)), 'NPP': ((0, 35), (0, 37))}

    def __init__(self,resultpath,savepath,mesh,first_year,last_year,type,
                 mapproj='rob',
                 cmap = 'viridis',
                 savefig=False,
                 verbose=True,
                 plotting=True,
                 output=True,
                 runname='fesom'):

        self.runname = runname
        self.resultpath = resultpath
        self.savepath = savepath
        self.mesh = mesh
        self.fyear = first_year
        self.lyear = last_year
        self.type = type
        self.mapproj = mapproj
        self.cmap = cmap
        self.savefig = savefig
        self.verbose = verbose
        self.plotting = plotting
        self.output = output

        # Validate before importing or reading anything; __init__ must raise,
        # not return a value, to report bad input.
        if self.type not in self._VARIABLES:
            raise ValueError("Please select 'Chl' or 'NPP'.")
        if last_year < first_year:
            raise ValueError('last_year must not be smaller than first_year.')

        import numpy as np
        import pyfesom2 as pf

        self.mapproj = pf.get_proj(self.mapproj)

        # Node indices of the latitude bands (Arctic: [lo, hi), SO: (lo, hi]).
        lat = mesh.y2
        regions = [
            ("Arc1", np.squeeze(np.where((lat >= 60) & (lat < 70)))),
            ("Arc2", np.squeeze(np.where((lat >= 70) & (lat < 80)))),
            ("Arc3", np.squeeze(np.where((lat >= 80) & (lat < 90)))),
            ("SO1", np.squeeze(np.where((lat <= -40) & (lat > -50)))),
            ("SO2", np.squeeze(np.where((lat <= -50) & (lat > -60)))),
            ("SO3", np.squeeze(np.where((lat <= -60) & (lat > -70)))),
        ]

        self.years = np.arange(self.fyear, self.lyear + 1, 1)
        self.months = np.arange(0, 12)

        series = self._load_regional_series(regions)

        self.seasonal_mean = {}
        for pft, by_region in series.items():
            if not any(by_region.values()):
                # Optional PFT with no output files in this run: nothing to average.
                continue
            self.seasonal_mean[pft] = {
                region: self._climatology(values, southern=region.startswith('SO'))
                for region, values in by_region.items()
            }

        if self.plotting:
            self._plot(self.seasonal_mean)

    def _load_regional_series(self, regions):
        '''
        Read every PFT's yearly file and collect the monthly regional means.

        Returns ``{pft: {region: [12-value list per year]}}``. A missing file of
        an optional PFT is skipped for that year; a missing file of a required
        PFT propagates the error raised by ``netCDF4.Dataset``.
        '''
        from pathlib import Path
        from netCDF4 import Dataset

        variables = self._VARIABLES[self.type]
        series = {pft: {region: [] for region, _ in regions} for pft in variables}

        for year in self.years:
            for pft, varname in variables.items():
                path = Path(self.resultpath) / f'{varname}.fesom.{year}.nc'
                if pft in self._OPTIONAL_PFTS and not path.is_file():
                    continue
                # Close the file as soon as the data is in memory ([:] copies it).
                with Dataset(path, 'r') as nc:
                    data = nc.variables[varname][:]
                for region, ind in regions:
                    series[pft][region].append(self.get_seansonal_data(data, ind))
        return series

    @staticmethod
    def _climatology(per_year, southern=False):
        '''
        Average a list of per-year 12-month series over the years.

        Returns None when ``per_year`` is empty. With ``southern=True`` the
        result is rotated to start in July.
        '''
        import numpy as np

        if not per_year:
            return None
        mean = np.array(per_year).mean(axis=0)
        if southern:
            mean = np.concatenate((mean[6:], mean[:6]))
        return mean

    def _plot(self, seasonal_mean):
        '''
        Draw the 2x3 panel figure (Arctic top row, SO bottom row) and return it.

        Saves the figure into ``self.savepath`` when ``self.savefig`` is True.
        '''
        import matplotlib.pyplot as plt
        import numpy as np
        from pathlib import Path

        months_name = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
        # SO series were rotated by _climatology to start in July.
        months_name_SO = np.concatenate((months_name[6:], months_name[:6]))

        panels = [
            ('Arc1', 'Seasonal cycle in the Arctic (Lat: 60-70)'),
            ('Arc2', 'Seasonal cycle in the Arctic (Lat: 70-80)'),
            ('Arc3', 'Seasonal cycle in the Arctic (Lat: 80-90)'),
            ('SO1', 'Seasonal cycle in the SO (Lat: 40-50)'),
            ('SO2', 'Seasonal cycle in the SO (Lat: 50-60)'),
            ('SO3', 'Seasonal cycle in the SO (Lat: 60-70)'),
        ]
        ylim_arctic, ylim_so = self._YLIMS[self.type]
        ylabel = self._YLABELS[self.type]

        fig = plt.figure(figsize=(12,9), facecolor='w', edgecolor='k', tight_layout = True)
        for position, (region, title) in enumerate(panels, start=1):
            southern = region.startswith('SO')
            ax = fig.add_subplot(2, 3, position)
            x = months_name_SO if southern else months_name
            # Every panel plots the same PFTs in the same order, so colours match.
            for pft, label in self._LABELS.items():
                if pft in seasonal_mean:
                    ax.plot(x, seasonal_mean[pft][region], label=label)
            ax.set_ylabel(ylabel)
            ax.set_ylim(*(ylim_so if southern else ylim_arctic))
            ax.set_title(title)

        # Legend from one panel's real handles: a fixed label list would be
        # zipped onto the wrong lines whenever an optional PFT is absent.
        handles, labels = fig.axes[0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, -0.05), ncol=6)

        if self.savefig:
            figname = f'{self.runname}_{self.type}_SeasonalCycle_{self.fyear}-{self.lyear}.png'
            # bbox_inches='tight' keeps the legend placed below the axes.
            fig.savefig(Path(self.savepath) / figname, dpi=300, bbox_inches='tight')
        return fig

    def get_seansonal_data(self, dataset, index):
        '''
        Return the regional mean of one year of output for each month.

        Parameters
        ----------
        dataset : array-like
            One year of model output indexed by month first (``dataset[month]``).
            For 'Chl' each month is first passed through ``pf.layermean_data``;
            for 'NPP' it goes straight to ``pf.areamean_data``.
        index : ndarray
            Mesh node indices of the region, passed as ``mask`` to
            ``pf.areamean_data``.

        Returns
        -------
        list
            One value per entry of ``self.months``.
        '''
        import contextlib
        import io
        import pyfesom2 as pf

        seasonal_lat = []
        for month in self.months:
            monthly = dataset[month]

            if self.type == 'Chl':
                # pf.layermean_data prints to stdout; silence it.
                with io.StringIO() as buf, contextlib.redirect_stdout(buf):
                    # collapse the vertical dimension (layer mean) before the area mean
                    monthly = pf.layermean_data(monthly, self.mesh)
                seasonal_lat.append(pf.areamean_data(monthly, self.mesh, mask=index))

            elif self.type == 'NPP':
                seasonal_lat.append(pf.areamean_data(monthly, self.mesh, mask=index))

        return seasonal_lat