In [1]:
import torch
import numpy as np
from torch.distributions import Exponential
import matplotlib.pyplot as plt
Bernoulli = torch.distributions.relaxed_bernoulli.RelaxedBernoulli
import pandas as pd


In [529]:
def sample_poisson_relaxed(lmbd, num_samples=100, temperature = 1e-2):
    """Draw an approximately Poisson(lmbd) count that stays differentiable in ``lmbd``.

    A Poisson process on the unit interval is simulated. ``num_samples``
    exponential inter-arrival times are drawn with ``rsample`` (reparameterised,
    so gradients flow to ``lmbd``). Their cumulative sum gives the arrival
    times, and every arrival before 1.0 is counted with the soft indicator
    ``sigmoid((1 - t) / temperature)``.

    The soft count is rounded with a straight-through estimator:
    - The forward value is the integer count.
    - The backward pass uses the gradient of the soft count.

    A plain ``round()`` has zero gradient everywhere and would discard the
    relaxation entirely.

    Args:
        lmbd: rate (float or tensor).
        num_samples: number of simulated arrivals. Counts are capped at this
            value, so it should be comfortably larger than ``lmbd``.
        temperature: width of the sigmoid indicator; smaller is closer to a
            hard count but has smaller gradients.

    Returns:
        Tensor with the shape of ``lmbd`` holding integer-valued counts.
    """
    z = Exponential(lmbd).rsample([num_samples])
    t = torch.cumsum(z, 0)
    relaxed_indicator = torch.sigmoid((1.0 - t) / temperature)
    N = relaxed_indicator.sum(0)
    # Straight-through rounding. (N.round() - N) is computed exactly in floating
    # point, so the forward value is exactly N.round().
    return N + (N.round() - N).detach()

class CohortMat:
    """Bundle of cohort state tensors: tree counts, species ids and diameters.

    Any argument left as ``None`` is replaced by a random default for 10
    cohorts. The defaults are drawn from NumPy's global RNG in the order
    nTree, Species, dbh:

    * ``nTree``   ~ Poisson(10), shape (10, 1)
    * ``Species`` ~ integers in [0, 5), shape (10,)
    * ``dbh``     ~ Uniform(10, 90), shape (10, 1)

    Non-tensor inputs are converted with ``torch.tensor``: float32 for nTree
    and dbh, int32 for Species. Inputs that are already tensors are stored
    as-is, with no copy and no dtype change.
    """

    def __init__(self, dbh=None, nTree=None, Species=None):
        # Draw missing defaults first, keeping the RNG call order nTree -> Species -> dbh.
        if nTree is None:
            nTree = np.random.poisson(10, [10, 1])
        if Species is None:
            Species = np.random.randint(5, size=[10])
        if dbh is None:
            dbh = np.random.uniform(10, 90, size=[10, 1])

        self.nTree = self._asTensor(nTree, torch.float32)
        self.Species = self._asTensor(Species, torch.int32)
        self.dbh = self._asTensor(dbh, torch.float32)

    @staticmethod
    def _asTensor(value, dtype):
        """Return ``value`` unchanged if it is a tensor, else a new tensor of ``dtype``."""
        return value if torch.is_tensor(value) else torch.tensor(value, dtype=dtype)

def height(cohortMat, par):
    """Cohort height as ``exp(dbh * par * 0.03)``.

    ``par`` is reshaped to a column so that one value per cohort broadcasts
    against ``cohortMat.dbh``. As built by CohortMat, ``dbh`` has shape (n, 1).
    """
    slope = par.reshape([-1, 1])
    exponent = cohortMat.dbh * slope * 0.03
    return torch.exp(exponent)

def BA(cohortMat):
    """Cross-sectional (basal) area of one stem per cohort: ``pi * (dbh / 100 / 2) ** 2``.

    The result is in square metres if ``dbh`` is in centimetres. That unit is
    an assumption and is not enforced here. It is not multiplied by ``nTree``.
    """
    radius = cohortMat.dbh / 100. / 2.
    return torch.pi * radius ** 2

def compF(cohortMat, h, parGlobal, minLight = 50., temperature = 1e-8):
    """Fraction of light available at each query height after shading by taller cohorts.

    For every query height ``x`` in ``h``, the shading load is the sum of
    ``BA / 0.1`` over all cohorts taller than ``x + 0.1``. The hard "taller
    than" test is a sigmoid with width ``temperature``. Available light is
    ``1 - load / minLight``, clipped below at 0.

    The work is vectorised as a (len(h), n_cohorts) matrix, so an empty cohort
    set gives an empty (or all-ones) result instead of failing.

    Args:
        cohortMat: object with ``dbh`` (assumed (n, 1), as built by CohortMat)
            and ``Species`` (n,) tensors.
        h: query heights. They are flattened, so shapes (m,), (m, 1) and 0-d
            all work.
        parGlobal: per-species parameters; column 0 is passed to ``height``.
        minLight: shading load at which available light reaches 0.
        temperature: sigmoid width. The default 1e-8 (the previous hard-coded
            value) makes the step effectively hard, so gradients through the
            indicator are essentially zero.

    Returns:
        1-D tensor of length m with values in [0, 1] (for positive ``minLight``).
    """
    # NOTE(review): the 0.1 divisor looks like an area scaling -- confirm its meaning.
    ba = (BA(cohortMat) / 0.1).reshape([1, -1])                                           # (1, n)
    cohortHeights = height(cohortMat, parGlobal[cohortMat.Species, [0]]).reshape([1, -1])  # (1, n)
    queryHeights = torch.as_tensor(h).reshape([-1, 1])                                    # (m, 1)

    # Soft indicator that cohort j is more than 0.1 taller than query height i.
    # When h holds the cohorts' own heights, the margin (with a small temperature)
    # keeps a cohort from shading itself.
    taller = torch.sigmoid((cohortHeights - queryHeights - 0.1) / temperature)           # (m, n)

    # NOTE(review): each cohort contributes the BA of a single stem; the load is
    # not weighted by nTree -- confirm this is intended.
    BA_height = (ba * taller).sum(1)                                                      # (m,)
    AL = 1. - BA_height / minLight
    return torch.clamp(AL, min = 0)
    

def growthF(cohortMat, timestep, parGrowth, parGlobal, envM):
    """Diameter increment for every cohort at ``timestep``.

    The increment is built from three parts:
    - A shade term, from the light available at each cohort's own height.
    - An environment term: a logistic response to row ``timestep`` of ``envM``.
    - A scale factor from column 3 of ``parGrowth``.

    The result is passed through softplus, so the increment is always positive.

    Columns of ``parGrowth`` (rows indexed by species):
    - 0 = shade coefficient
    - 1:3 = environment coefficients
    - 3 = growth scale

    Returns:
        Tensor of shape (n_cohorts, 1).
    """
    species = cohortMat.Species
    ownHeights = height(cohortMat, parGlobal[species, [0]])
    light = compF(cohortMat, h=ownHeights, parGlobal=parGlobal)

    # NOTE(review): light is reshaped to (1, n) and matmul'ed with an (n, 1)
    # parameter column. fEnv3 therefore returns ONE (1, 1) value summed over all
    # cohorts, which is then broadcast to every cohort. mortF uses the
    # element-wise fEnv3dot for its shade term -- confirm which is intended.
    shadeTerm = fEnv3(light.reshape([1, -1]), parGrowth[species, [0]].reshape([-1, 1]))

    # NOTE(review): the matmul against the two coefficients in columns 1:3 only
    # works when envM has exactly two columns; the 0:3 slice is then the full row.
    envTerm = fEnv1(envM[timestep, 0:3], parGrowth[species, 1:3].t())

    # 1. - envTerm + shadeTerm parses as (1 - envTerm) + shadeTerm.
    saturation = 1. - torch.pow(1. - envTerm + shadeTerm, 4.0)
    rawGrowth = (saturation * parGrowth[species, [3]]).reshape([-1, 1])
    return torch.nn.functional.softplus(rawGrowth)

def fEnv1(env, par):
    """Logistic response to a linear combination of covariates: ``sigmoid(env @ par)``.

    As called in this notebook, ``env`` is one row of the environment matrix
    (shape (k,) or (1, k)), and ``par`` is (k, n) with one column per cohort
    or species.
    """
    linear = env.matmul(par)
    return torch.sigmoid(linear)

def fEnv3(env, par):
    """Logistic response to squared covariates, combined linearly: ``sigmoid(env**2 @ par)``.

    Squaring makes the response symmetric in the sign of each covariate.
    """
    squared = env ** 2.0
    return torch.sigmoid(squared.matmul(par))

def fEnv3dot(env, par):
    """Element-wise counterpart of ``fEnv3``: ``sigmoid(env**2 * par)``, no summation."""
    weighted = (env ** 2.0) * par
    return torch.sigmoid(weighted)


def mortF(cohortMat, timestep, parMort,parGlobal, envM):
    """Number of trees dying in each cohort at ``timestep``.

    The mortality probability is the sum of three terms, plus Uniform(-0.2, 0.2)
    noise, clamped to [0, 1]:

    * shade:       ``1 - sigmoid(light**2 * parMort[s, 0])``, with light from
                   ``compF`` at the cohort's own height;
    * environment: ``1 - sigmoid(env**2 @ parMort[s, 1:3])``, for row
                   ``timestep`` of ``envM``;
    * size:        ``0.1 * max(dbh / parMort[s, 3], 1e-5) ** 2.3``.

    The probability times ``nTree`` is rounded with a straight-through
    estimator: the forward pass uses the rounded value, and the backward pass
    uses the gradient of the unrounded value.

    Returns:
        Tensor of shape (n_cohorts, 1).
    """
    species = cohortMat.Species
    light = compF(cohortMat, h=height(cohortMat, parGlobal[species, [0]]), parGlobal=parGlobal)

    shadeTerm = 1 - fEnv3dot(light.reshape([-1, 1]), parMort[species, [0]].reshape([-1, 1])).reshape([-1, 1])
    envTerm = (1 - fEnv3(envM[timestep, ...], parMort[species, 1:3].t())).reshape([-1, 1])
    # Clamp keeps the fractional power well-defined for tiny/non-positive ratios.
    relativeSize = torch.clamp(cohortMat.dbh / parMort[species, 3].reshape([-1, 1]), min=0.00001)
    sizeTerm = 0.1 * relativeSize.pow(2.3).reshape([-1, 1])

    predicted = (shadeTerm + envTerm + sizeTerm).reshape([-1, 1])
    noise = torch.distributions.Uniform(-0.2, 0.2).rsample([predicted.shape[0], 1])
    probability = torch.clamp(predicted + noise, 0, 1)

    deaths = probability * cohortMat.nTree
    deaths = deaths + deaths.round().detach() - deaths.detach()
    return deaths

def regF(cohortMat, timestep, parReg,parGlobal, envM):
    """Number of recruits for each row of ``parReg`` (one row per species) at ``timestep``.

    The recruitment "drive" is the sum of two terms:
    - A light term, ``sigmoid((parReg[:, 0] - ground_light) / 0.01)``, where
      ground_light is ``compF`` evaluated at height 0.
    - An environment term, ``sigmoid(env_row @ parReg[:, 1:3])``.

    Recruits are ``exp(drive + noise) * 0.9 - 1``, with noise ~ Uniform(-0.6*drive,
    0.6*drive). They are rounded with a straight-through estimator.

    Since drive > 0 and the noise is at least -0.6*drive, the pre-rounding value
    is at least -0.1, so rounded counts are never negative. NOTE(review): if
    both sigmoids underflow to 0, the Uniform bounds collapse to 0 -- confirm
    this cannot happen for realistic parameters.

    Returns:
        Tensor of shape (parReg.shape[0], 1).
    """
    groundLight = compF(cohortMat, h = torch.zeros([1], dtype=torch.float32), parGlobal=parGlobal)
    lightTerm = torch.sigmoid((parReg[..., [0]] - groundLight) / 1e-2)
    envRow = envM[timestep, ...].reshape([1, -1])
    envTerm = fEnv1(envRow, parReg[..., 1:3].t()).reshape([-1, 1])

    drive = lightTerm + envTerm
    halfWidth = 0.6 * drive
    noise = torch.distributions.Uniform(-halfWidth, halfWidth).rsample([1])
    recruits = (drive + noise).exp() * 0.9 - 1.
    recruits = recruits + recruits.round().detach() - recruits.detach()
    return recruits.reshape([-1, 1])


# parReg = torch.randn([5, 3], requires_grad=True, dtype=torch.float32)
# parGlobal = torch.rand([5, 1], requires_grad=True, dtype=torch.float32)
# MortNP = np.concatenate([np.random.uniform(-1, 1, size = [5, 3]), np.random.uniform(200, 350, size = [5, 1])], 1 )
# 
# parMort = torch.tensor(MortNP, requires_grad=True, dtype=torch.float32)
# parGrowth = torch.tensor(np.random.uniform(-0.5, 0.5, size = [5, 4]), requires_grad=True, dtype=torch.float32)
# opt = torch.optim.Adam(params=[parGrowth, parMort], lr=10000000.1)
# 
# cohorts = CohortMat()
# opt.zero_grad()
# timestep = 1
# envM = envT
# 
# g = growthF(cohorts, timestep, parGrowth.float(), parGlobal, envM)
# cohorts.dbh=cohorts.dbh+g
# m = mortF(cohorts, timestep, parMort, parGlobal, envM)
# cohorts.nTree = torch.clamp(cohorts.nTree - m, min = 0)
#         #keep.append([cohorts.Species, cohorts.nTree, cohorts.dbh])
# 
# indices = (cohorts.nTree > 0).reshape([-1])
# r = regF(cohorts, timestep, parReg, parGlobal, envM)
# #m = torch.clamp(m + torch.distributions.Uniform(-0.1,0.1).rsample([10,1]), 0 , 1)*cohorts.nTree
# #cohorts.nTree = torch.clamp(cohorts.nTree - m, min = 0)
# #cohorts.nTree = torch.clamp(cohorts.nTree - m, min = 0)
# ll = torch.nn.functional.mse_loss(r, torch.zeros_like(r, dtype=torch.float32))
# ll.backward(retain_graph=True)
# print(parReg.grad)
# print(r)


tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 1.4763e+00, 1.5473e+00],
        [0.0000e+00, 1.2114e+01, 1.2696e+01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [8.5811e-31, 1.2195e-01, 1.2781e-01]])
tensor([[ 0.0000],
        [ 4.0000],
        [11.0000],
        [ 0.0000],
        [ 1.0000]], grad_fn=<ViewBackward0>)


In [528]:
# Quick look at the regeneration sub-model for the first timestep.
# NOTE(review): relies on `cohorts`, `parReg`, `parGlobal` and `envM` from the
# synthetic-setup cell (In [3]) further down; run that cell first.
# `.float()` keeps envM in line with the float32 parameters, because the
# torch.matmul inside fEnv1 rejects mixed float64/float32 operands.
regF(cohorts, timestep=0, parReg=parReg, parGlobal=parGlobal, envM=envM.float())

tensor([[0.],
        [0.],
        [3.],
        [0.],
        [1.]], grad_fn=<ViewBackward0>)

In [3]:
# Synthetic playground inputs: a random 2-covariate environment, a random stand
# and random parameters.
# Everything is float32 so it can be combined with the float32 CohortMat
# tensors; torch.matmul (used in fEnv1/fEnv3) rejects mixed float64/float32.
envM = torch.tensor(np.random.normal(0.0, 1, size=[100,2]), dtype=torch.float32)
cohorts = CohortMat()

# Column use (rows = species), as indexed in regF / height / mortF / growthF:
#   parReg:    0 = light threshold, 1:3 = environment coefficients
#   parGlobal: 0 = height slope
#   parMort:   0 = shade, 1:3 = environment, 3 = size scale
#   parGrowth: 0 = shade, 1:3 = environment, 3 = growth scale
parReg = torch.randn([5, 3], requires_grad=True, dtype=torch.float32)
parGlobal = torch.rand([5, 1], requires_grad=True, dtype=torch.float32)
MortNP = np.concatenate([np.random.uniform(0.0, 1.0, size = [5, 3]), np.random.uniform(200, 350, size = [5, 1])], 1 )

parMort = torch.tensor(MortNP, requires_grad=True, dtype=torch.float32)
parGrowth = torch.tensor(np.random.uniform(0, 8, size = [5, 4]), requires_grad=True, dtype=torch.float32)


In [530]:

def run_model(cohorts, parReg, parGrowth, parMort, parGlobal, envM, time = 100):
    """Simulate ``time`` steps of stand dynamics and return per-species basal area.

    Each step:
      1. records, for every species, the summed ``BA`` of its cohorts;
      2. grows (``growthF``) and thins (``mortF``) the existing cohorts;
      3. computes recruits (``regF``), drops cohorts whose ``nTree`` reached 0,
         and appends one new cohort (dbh = 1) for every species with recruits > 0.

    Args:
        cohorts: initial ``CohortMat``. Its ``dbh`` and ``nTree`` attributes are
            overwritten in the first step; later steps work on new objects.
        parReg, parGrowth, parMort, parGlobal: per-species parameter tensors.
            The number of species is ``parReg.shape[0]``, since regF returns one
            row per species (5 in this notebook).
        envM: environment matrix, indexed by row ``timestep``.
        time: number of simulated steps.

    Returns:
        float32 tensor of shape (time, n_species). Entry [t, s] is the summed BA
        of species s at the start of step t (0 where the species is absent).
    """
    nSpecies = parReg.shape[0]
    # Species ids of the recruits: row i of regF's output belongs to species i.
    recruitSpecies = torch.arange(nSpecies, dtype=torch.int32)
    yearlyBA = []

    for timestep in range(time):
        # 1. Per-species basal area via scatter-add (duplicate species accumulate).
        ba = BA(cohorts).reshape([-1])
        row = torch.zeros([nSpecies])
        yearlyBA.append(row.index_add(0, cohorts.Species.long(), ba.to(row.dtype)))

        # 2. Growth and mortality need at least one cohort. An empty stand can
        #    still be re-colonised by regeneration below.
        if cohorts.Species.numel() > 0:
            g = growthF(cohorts, timestep, parGrowth, parGlobal, envM)
            cohorts.dbh = cohorts.dbh + g
            m = mortF(cohorts, timestep, parMort, parGlobal, envM)
            cohorts.nTree = torch.clamp(cohorts.nTree - m, min = 0)

        # 3. Regeneration.
        # NOTE(review): regF runs before the dead cohorts (nTree == 0) are
        # removed, so they still shade the ground in this step.
        alive = (cohorts.nTree > 0).reshape([-1])
        r = regF(cohorts, timestep, parReg, parGlobal, envM)
        recruited = r.reshape([-1]) > 0

        # Concatenating an empty selection is a no-op, so steps without
        # recruits need no special case.
        cohorts = CohortMat(
            nTree=torch.cat([cohorts.nTree[alive, ...].reshape([-1, 1]),
                             r[recruited, ...].reshape([-1, 1])], 0),
            Species=torch.cat([cohorts.Species[alive], recruitSpecies[recruited]], 0),
            dbh=torch.cat([cohorts.dbh[alive, ...].reshape([-1, 1]),
                           torch.ones([nSpecies])[recruited].reshape([-1, 1])], 0),
        )

    if not yearlyBA:
        return torch.zeros([0, nSpecies])
    return torch.stack(yearlyBA)



In [564]:
# Data location, relative to the notebook's working directory.
DATA_DIR = ".."
SEED = 42

# Seed both RNGs so parameter initialisation and the stochastic simulations in
# the training cell are reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

ENV = np.loadtxt(f"{DATA_DIR}/evaluation-env.csv", delimiter=",", dtype=float, skiprows=1)
obs = pd.read_csv(f"{DATA_DIR}/response.csv")
response = obs.to_numpy()
envT = torch.tensor(ENV, dtype=torch.float32)
respT = torch.tensor(response, dtype=torch.float32)
# growthF/mortF/regF multiply one environment row with the two coefficients in
# parameter columns 1:3, so exactly two covariate columns are required.
assert envT.ndim == 2 and envT.shape[1] == 2, f"expected a (T, 2) environment, got {tuple(envT.shape)}"


parReg = torch.randn([5, 3], requires_grad=True, dtype=torch.float32)
parGlobal = torch.rand([5, 1], requires_grad=True, dtype=torch.float32)
MortNP = np.concatenate([np.random.uniform(0.0, 1.0, size = [5, 3]), np.random.uniform(200, 350, size = [5, 1])], 1 )

parMort = torch.tensor(MortNP, requires_grad=True, dtype=torch.float32)
parGrowth = torch.tensor(np.random.uniform(0, 8, size = [5, 4]), requires_grad=True, dtype=torch.float32)
opt = torch.optim.Adam(params=[parGrowth, parMort, parGlobal, parReg], lr=0.1)



In [565]:
N_EPOCHS = 100               # optimiser steps
N_REPLICATES = 10            # stochastic simulations averaged per step
SIM_YEARS = 200              # simulated timesteps; presumably the number of rows in respT -- confirm
N_SPECIES = parReg.shape[0]  # one initial cohort per species

lossHistory = []
for epoch in range(N_EPOCHS):
    opt.zero_grad()
    losses = []
    for _ in range(N_REPLICATES):
        # Fresh stand for every replicate: one tree with dbh 1 per species.
        initial = CohortMat(dbh=np.ones([N_SPECIES, 1]),
                            Species=np.arange(0, N_SPECIES, 1),
                            nTree=np.ones([N_SPECIES, 1]))
        simulated = run_model(initial, parReg, parGrowth, parMort, parGlobal, envT, time=SIM_YEARS)
        losses.append(torch.nn.functional.mse_loss(simulated, respT))

    # Average over replicates so one backward pass covers all simulations.
    loss = torch.stack(losses).mean()
    loss.backward()
    opt.step()

    lossValue = loss.item()
    lossHistory.append(lossValue)
    print(lossValue)
    

1328.486083984375
1322.5018310546875
1317.565673828125
1315.9718017578125
1305.744384765625
1306.151123046875
1303.349609375
1297.812255859375
1288.832275390625
1287.3173828125
1267.5047607421875
1264.754638671875
1271.888427734375
1234.842041015625
1270.4571533203125
1223.2869873046875
1191.0506591796875
1197.0771484375
1171.681884765625
1144.025146484375
1146.7667236328125
1117.939208984375
1091.891845703125
1087.8377685546875
1121.3739013671875


In [None]:
#unimodal_function <- function(x, c, w) {
#  return(pmax(0, 1 - ((x - c) / w)^2))
#}

