In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm

class Species:
    """
    A species whose abundance is a scaled normal (Gaussian) pdf of temperature.

    Parameters
    ----------
    opt_temp : float
        Temperature at which abundance peaks; used as the mean (``loc``) of the pdf.
    std_dev : float
        Width of the curve; used as the standard deviation (``scale``) of the pdf.
    scale : float, optional
        Multiplicative factor applied to the pdf (default 1).
    """

    def __init__(self, opt_temp, std_dev, scale=1):
        self.opt_temp, self.std_dev, self.scale = opt_temp, std_dev, scale

    def calculate_abundance(self, temp_array):
        """
        Evaluate ``scale * normal_pdf(temp_array; opt_temp, std_dev)`` elementwise.

        ``temp_array`` may be a scalar or any array-like accepted by
        ``scipy.stats.norm.pdf``; the result has the same shape.
        """
        density = norm.pdf(temp_array, loc=self.opt_temp, scale=self.std_dev)
        return self.scale * density

class Community:
    """
    This class represents a community of species with temperature-dependent abundance functions.

    Parameters
    ----------
    species_list : list
        Members of the community. Each member must provide a
        ``calculate_abundance(temp_array)`` method (e.g. a ``Species``).
    """
    def __init__(self, species_list):
        self.species_list = species_list

    def cal_tot_abun(self, temp_array):
        """
        Calculates the total abundance of the community for a given temperature array.

        Returns the elementwise sum of every member's abundance (the integer 0 when the
        community is empty).
        """
        return sum(species.calculate_abundance(temp_array) for species in self.species_list)

    def cal_rel_abun(self, sp, temp_array):
        """
        Calculates the relative abundance of one species, or of a group of species,
        within the community for a given temperature array.

        Parameters
        ----------
        sp : Species or list/tuple of Species
            A single species (any object with a ``calculate_abundance`` method), or a
            group of species whose abundances are summed and treated as one unit.
        temp_array : array-like
            Temperatures at which abundances are evaluated.

        Returns
        -------
        The abundance of ``sp`` divided elementwise by the total community abundance.
        Where the total is zero the ratio is undefined (numpy yields nan/inf).

        Raises
        ------
        TypeError
            If ``sp`` is neither a species nor a list/tuple of species.
        """
        if hasattr(sp, "calculate_abundance"):
            abundance = sp.calculate_abundance(temp_array)
        elif isinstance(sp, (list, tuple)):
            # a group's abundance is the sum of its members' abundances
            abundance = sum(species.calculate_abundance(temp_array) for species in sp)
        else:
            # previously this fell through and raised an UnboundLocalError
            raise TypeError(
                f"sp must be a Species or a list/tuple of Species, got {type(sp).__name__}"
            )
        total_abundance = self.cal_tot_abun(temp_array)
        return abundance / total_abundance
    
## Scenario 1: two species with optima at opposite ends of the temperature range
temp_range = np.linspace(0, 30, 100)

sp1 = Species(0, 5, 100)
sp2 = Species(30, 5, 100)

community = Community([sp1, sp2])

## plot the absolute (left) and relative (right) abundance of each species
fig, axs = plt.subplots(1, 2, figsize=(10, 4), sharex=True)

for sp in community.species_list:
    label = f"Species {sp.opt_temp}"
    axs[0].plot(temp_range, sp.calculate_abundance(temp_range), label=label)
    axs[1].plot(temp_range, community.cal_rel_abun(sp, temp_range), label=label)

## add labels
axs[0].set_title("Absolute abundance")
axs[0].set_xlabel("Temperature")
axs[0].set_ylabel("Abundance")
axs[0].legend()

axs[1].set_title("Relative abundance")
axs[1].set_xlabel("Temperature")
axs[1].set_ylabel("Relative abundance")
axs[1].legend()

## savefig does not create missing directories, so make sure the target exists
output_path = Path("../output/fig_r1.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches="tight")

In [None]:
## Scenario 2: two species with the same optimum (15) but different widths and scales
## (reuses temp_range defined in the first cell)
sp1 = Species(15, 7)
sp2 = Species(15, 3, 2)

community = Community([sp1, sp2])

fig, axs = plt.subplots(1, 2, figsize=(10, 4), sharex=True)

## left: absolute TPCs; right: relative abundance within the community
for sp in community.species_list:
    # both species share opt_temp=15, so the label also needs std_dev to tell them apart
    label = f"Species {sp.opt_temp} (sd={sp.std_dev})"
    axs[0].plot(temp_range, sp.calculate_abundance(temp_range), label=label)
    axs[1].plot(temp_range, community.cal_rel_abun(sp, temp_range), label=label)

axs[0].set_title("Absolute abundance")
axs[0].set_xlabel("Temperature")
axs[0].set_ylabel("Abundance")
axs[0].legend()

axs[1].set_title("Relative abundance")
axs[1].set_xlabel("Temperature")
axs[1].set_ylabel("Relative abundance")
axs[1].legend();

In [None]:
## Scenario 3: five species grouped by optimum temperature into three ecogroups:
## low (3, 5), mid (15) and high (25, 30). Reuses temp_range from the first cell.
sp1 = Species(3, 5)
sp2 = Species(5, 10)
sp3 = Species(15, 7, 1)
sp4 = Species(25, 5, 1)
sp5 = Species(30, 3, 1)

community = Community([sp1, sp2, sp3, sp4, sp5])

ecogroup1 = [sp1, sp2]
ecogroup2 = sp3
ecogroup3 = [sp4, sp5]

## calculate the relative abundance of each ecogroup in the community
rel_abundance1 = community.cal_rel_abun(ecogroup1, temp_range)
rel_abundance2 = community.cal_rel_abun(ecogroup2, temp_range)
rel_abundance3 = community.cal_rel_abun(ecogroup3, temp_range)

# the three ecogroups partition the community, so their shares must sum to 1 everywhere
assert np.allclose(rel_abundance1 + rel_abundance2 + rel_abundance3, 1.0)

# explicit figure/axes: the pyplot state machine would otherwise draw onto the
# previous cell's last axes when not running with inline figures
fig, ax = plt.subplots(figsize=(5, 4))
ax.plot(temp_range, rel_abundance1, label="Ecogroup 1", linestyle="--")
ax.plot(temp_range, rel_abundance2, label="Ecogroup 2", linestyle="--")
ax.plot(temp_range, rel_abundance3, label="Ecogroup 3", linestyle="--")

ax.set_title("Relative abundance of ecogroups")
ax.set_xlabel("Temperature")
ax.set_ylabel("Relative abundance")
ax.legend();

In [None]:
# ## apply to 2D temperature array
# import xarray as xr
# sst_pi = xr.open_dataset("/Users/yingrui/Science/lgm_foram_census/tidy/HadISST_PI.nc")
# ## flip sst_pi along its first dimension (reverse the row order)
# sst_pi = sst_pi['sst'][::-1, :]

# def convert360_180(_ds):
#     """
#     convert longitude from 0-360 to -180 -- 180 deg
#     """
#     # only shift if longitudes are still in the 0-360 convention (no negative values)
#     attrs = _ds['lon'].attrs
#     if _ds['lon'].min() >= 0:
#         with xr.set_options(keep_attrs=True): 
#             _ds.coords['lon'] = (_ds['lon'] + 180) % 360 - 180
#         _ds = _ds.sortby('lon')
#     return _ds

# sst_lgm = xr.open_dataset("/Users/yingrui/Science/lgm_foram_census/tidy/Tierney2020_DA_ocn_regrid.nc")
# sst_lgm = convert360_180(sst_lgm['SSTLGM'])

In [None]:
# fig, axs = plt.subplots(1,3, figsize=(12,3), sharex=True)
# sp1 = Species(3, 5)
# sp2 = Species(5, 10)
# sp3 = Species(15, 7, 1)
# sp4 = Species(25, 5, 1)
# sp5 = Species(30, 3, 1)
# community = Community([sp1, sp2, sp3, sp4, sp5])

# data_pi = community.cal_rel_abun(ecogroup2, sst_pi)

# p0 = axs[0].pcolormesh(data_pi, vmax=0.8, cmap="Spectral_r")
# ## add colorbar
# fig.colorbar(p0, ax=axs[0], orientation='horizontal')

# sp1 = Species(3, 5)
# sp2 = Species(5, 10)
# sp3 = Species(15, 7, 1)
# sp4 = Species(25, 5, 1)
# sp5 = Species(30, 3, 1)

# community = Community([sp1, sp2, sp3, sp4, sp5])
# data_lgm = community.cal_rel_abun(ecogroup2, sst_lgm)
# p1 = axs[1].pcolormesh(data_lgm, vmax=0.8,  cmap="Spectral_r")
# fig.colorbar(p1, ax=axs[1], orientation='horizontal')

# data_diff = data_pi - data_lgm
# p2 = axs[2].pcolormesh(data_diff, cmap="coolwarm", norm = plt.Normalize(vmin=-0.3, vmax=0.3))
# fig.colorbar(p2, ax=axs[2], orientation='horizontal')

# #fig.colorbar(p1)
# #fig.colorbar(p2)

# axs[0].set_title("PI")
# axs[1].set_title("LGM")
# axs[2].set_title("PI - LGM (no acclimation)")