In [2]:
import torch
import torch.nn.functional as F

In [3]:
import torch
import torch.nn.functional as F

def UH_conv(x, UH, viewmode=1):
    """Causally convolve every series in ``x`` with its own unit hydrograph.

    For each series ``i`` this computes
    ``y[i, t] = sum_k UH[i, k] * x[i, t - k]`` (terms with ``t - k < 0`` are zero),
    so the output has exactly the same length as the input.

    Args:
        x: tensor holding ``nb = x.shape[0]`` series of length ``L = x.shape[-1]``.
            It is reshaped to ``[1, nb, L]``, so any middle dimensions must have
            size 1 (e.g. ``[nb, L]`` or ``[nb, 1, L]``).
        UH: unit hydrographs with ``nb * m`` elements where ``m = UH.shape[-1]``
            (e.g. ``[nb, m]`` or ``[nb, 1, m]``). Moved to ``x``'s device and dtype.
        viewmode: layout selector; only ``1`` is implemented.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``viewmode`` is not 1 (previously this fell through to
            ``F.conv1d`` with undefined locals and raised ``UnboundLocalError``).
    """
    if viewmode != 1:
        raise ValueError(f"UH_conv: unsupported viewmode {viewmode!r}; only viewmode=1 is implemented")

    # Ensure x and UH are on the same device and share a dtype (conv1d requires both).
    UH = UH.to(device=x.device, dtype=x.dtype)

    mm = x.shape
    nb = mm[0]
    length = mm[-1]
    m = UH.shape[-1]
    padd = m - 1

    # reshape (not view) so non-contiguous inputs such as permuted tensors are accepted.
    xx = x.reshape([1, nb, length])
    w = UH.reshape([nb, 1, m])

    # conv1d computes a cross-correlation, so flip the kernel to get a true convolution;
    # groups=nb makes each series use only its own hydrograph.
    y = F.conv1d(xx, torch.flip(w, [2]), groups=nb, padding=padd, stride=1)
    # Symmetric padding yields length + padd outputs; the first ``length`` are the causal ones.
    # Slicing with ``:length`` (instead of ``0:-padd``) also works when padd == 0.
    y = y[:, :, :length]
    return y.reshape(mm)


def UH_gamma(a, b, lenF=10):
    """Build discrete, normalised gamma-shaped unit hydrographs.

    Only the first ``lenF`` slices along dim 0 of ``a`` and ``b`` are used. Both are
    viewed with ``a``'s trailing shape, so they must have matching dims 1 and 2.

    Shape  alpha = relu(a) + 0.1   (never below 0.1)
    Scale  theta = relu(b) + 0.5   (never below 0.5)

    The gamma density is evaluated at t = 0.5, 1.5, ..., lenF - 0.5.

    Returns:
        Tensor of shape ``[lenF, a.shape[1], a.shape[2]]`` whose entries sum to 1
        along dim 0.
    """
    shp = a.shape
    target = [lenF, shp[1], shp[2]]

    alpha = F.relu(a[0:lenF, :, :]).view(target) + 0.1
    scale = F.relu(b[0:lenF, :, :]).view(target) + 0.5

    steps = torch.arange(0.5, lenF * 1.0, step=1.0)
    steps = steps.view([lenF, 1, 1]).repeat([1, shp[1], shp[2]]).to(alpha.device)

    # Gamma(alpha) * theta**alpha: the pdf's normalising constant.
    normaliser = (torch.lgamma(alpha).exp()) * (scale ** alpha)
    pdf = 1 / normaliser * (steps ** (alpha - 1)) * torch.exp(-steps / scale)

    # Rescale so each hydrograph sums to exactly 1 over its lenF ordinates.
    return pdf / pdf.sum(dim=0, keepdim=True)

class HBVMul(torch.nn.Module):
    """Multi-component HBV model implemented in PyTorch by Dapeng Feng.

    ``mu`` parameter sets ("components") are simulated side by side for every grid
    cell, and their discharge is combined in ``forward``. The module holds no
    trainable weights; all parameters are passed to ``forward``.
    """

    def __init__(self):
        """Initiate an HBV instance"""
        super(HBVMul, self).__init__()

    def forward(self, x, parameters, mu, muwts=None, rtwts=None, bufftime=0, outstate=False,
                routOpt=False, comprout=False, corrwts=None, pcorr=None):
        """Run the HBV simulation.

        Args:
            x: forcings indexed ``x[time, grid, var]`` with var 0 = precipitation,
                1 = temperature, 2 = potential ET.
            parameters: normalised parameters indexed ``parameters[grid, i, component]``
                for i in 0..11. Each is mapped linearly onto its bounds in ``parascaLst``,
                so values are presumably in [0, 1].
            mu: number of components per grid cell.
            muwts: optional component weights, multiplied with the per-component discharge
                ``[Nstep, Ngrid, mu]`` and summed over the last axis (presumably
                ``[Ngrid, mu]``). If None, components are averaged.
            rtwts: routing weights; only forwarded to the warm-up run (routing is not implemented).
            bufftime: number of leading time steps used only to spin up the stores
                (run under ``torch.no_grad``); outputs cover ``x[bufftime:]``.
            outstate: if True, return ``(Qs, SNOWPACK, MELTWATER, SM, SUZ, SLZ)`` with the
                final states (each ``[Ngrid, mu]``).
            routOpt: must be falsy; routing is not implemented.
            comprout: unused (routing only).
            corrwts, pcorr: optional precipitation correction. The per-grid factor is
                ``pcorr[0] + corrwts[:, 0] * (pcorr[1] - pcorr[0])``.

        Returns:
            If ``outstate`` is True, the tuple above; ``Qs`` has shape ``[Nstep, Ngrid, 1]``.
            Otherwise ``Qall`` of shape ``[Nstep, Ngrid, 5]`` holding
            (Q, mean Q0, mean Q1, mean Q2, mean actual ET).

        Raises:
            NotImplementedError: if ``routOpt`` is truthy. Previously this path left
                ``Qs`` unassigned and crashed with ``UnboundLocalError`` after running
                the whole simulation.
        """
        if routOpt:
            raise NotImplementedError(
                "HBVMul.forward: routing (routOpt=True) is not implemented; call with routOpt=False")

        PRECS = 1e-5  # Keep the numerical calculation stable: floor for soil moisture
        device = x.device

        # ---- initial states ------------------------------------------------------
        if bufftime > 0:
            # Spin up the stores on the first `bufftime` steps without tracking gradients.
            with torch.no_grad():
                xinit = x[0:bufftime, :, :]
                initmodel = HBVMul()
                _, SNOWPACK, MELTWATER, SM, SUZ, SLZ = initmodel(
                    xinit, parameters, mu, muwts, rtwts, bufftime=0, outstate=True,
                    routOpt=False, comprout=False, corrwts=corrwts, pcorr=pcorr)
        else:
            # No warm-up: start every store at a small positive value.
            Ngrid = x.shape[1]
            SNOWPACK, MELTWATER, SM, SUZ, SLZ = (
                torch.full([Ngrid, mu], 0.001, dtype=torch.float32, device=device)
                for _ in range(5))

        # ---- forcings: [Nstep, Ngrid] replicated to [Nstep, Ngrid, mu] -------------
        P = x[bufftime:, :, 0]
        Nstep, Ngrid = P.size()
        if pcorr is not None:
            parPCORR = pcorr[0] + corrwts[:, 0] * (pcorr[1] - pcorr[0])
            P = parPCORR.unsqueeze(0).repeat(Nstep, 1) * P

        Pm = P.unsqueeze(2).repeat(1, 1, mu)                     # Precipitation
        Tm = x[bufftime:, :, 1].unsqueeze(2).repeat(1, 1, mu)    # Temperature
        ETpm = x[bufftime:, :, 2].unsqueeze(2).repeat(1, 1, mu)  # Potential ET

        # ---- scale parameters to physical ranges ----------------------------------
        # [lower, upper] bounds; parameters[:, i, :] is mapped linearly onto entry i.
        parascaLst = [
            [1, 6], [50, 1000], [0.05, 0.9], [0.01, 0.5],
            [0.001, 0.2], [0.2, 1], [0, 10], [0, 100],
            [-2.5, 2.5], [0.5, 10], [0, 0.1], [0, 0.2]
        ]
        # Each scaled parameter has shape [Ngrid, mu].
        (parBETA, parFC, parK0, parK1, parK2, parLP,
         parPERC, parUZL, parTT, parCFMAX, parCFR, parCWH) = [
            lo + parameters[:, i, :] * (hi - lo) for i, (lo, hi) in enumerate(parascaLst)
        ]

        # ---- per-step, per-component outputs [Nstep, Ngrid, mu] -------------------
        # Every time slice is overwritten inside the loop.
        Qsimmu, ETmu, Qsimmu0, Qsimmu1, Qsimmu2 = (
            torch.full(Pm.size(), 0.001, dtype=torch.float32, device=device)
            for _ in range(5))

        for t in range(Nstep):
            Tt = Tm[t, :, :]

            # Separate precipitation into liquid and solid components
            PRECIP = Pm[t, :, :]
            RAIN = PRECIP * (Tt >= parTT).float()
            SNOW = PRECIP * (Tt < parTT).float()

            # Snow routine
            SNOWPACK = SNOWPACK + SNOW
            melt = torch.clamp(parCFMAX * (Tt - parTT), min=0.0)
            melt = torch.min(melt, SNOWPACK)
            MELTWATER = MELTWATER + melt
            SNOWPACK = SNOWPACK - melt
            refreezing = torch.clamp(parCFR * parCFMAX * (parTT - Tt), min=0.0)
            refreezing = torch.min(refreezing, MELTWATER)
            SNOWPACK = SNOWPACK + refreezing
            MELTWATER = MELTWATER - refreezing
            # Meltwater above the retained amount (CWH * SNOWPACK) drains to the soil.
            tosoil = torch.clamp(MELTWATER - (parCWH * SNOWPACK), min=0.0)
            MELTWATER = MELTWATER - tosoil

            # Soil and evaporation
            soil_wetness = torch.clamp((SM / parFC) ** parBETA, min=0.0, max=1.0)
            recharge = (RAIN + tosoil) * soil_wetness
            SM = SM + RAIN + tosoil - recharge
            excess = torch.clamp(SM - parFC, min=0.0)
            SM = SM - excess
            evapfactor = torch.clamp(SM / (parLP * parFC), min=0.0, max=1.0)
            ETact = torch.min(SM, ETpm[t, :, :] * evapfactor)
            SM = torch.clamp(SM - ETact, min=PRECS)  # SM cannot be zero for gradient tracking
            ETmu[t, :, :] = ETact

            # Groundwater boxes
            SUZ = SUZ + recharge + excess
            PERC = torch.min(SUZ, parPERC)
            SUZ = SUZ - PERC
            Q0 = parK0 * torch.clamp(SUZ - parUZL, min=0.0)
            SUZ = SUZ - Q0
            Q1 = parK1 * SUZ
            SUZ = SUZ - Q1
            SLZ = SLZ + PERC
            Q2 = parK2 * SLZ
            SLZ = SLZ - Q2

            Qsimmu[t, :, :] = Q0 + Q1 + Q2
            Qsimmu0[t, :, :] = Q0
            Qsimmu1[t, :, :] = Q1
            Qsimmu2[t, :, :] = Q2

        # Component means of the three runoff terms and of actual ET, each [Nstep, Ngrid, 1].
        # NOTE(review): these are unweighted even when `muwts` is given, unlike Qsimave below.
        Qsimave0 = Qsimmu0.mean(-1, keepdim=True)
        Qsimave1 = Qsimmu1.mean(-1, keepdim=True)
        Qsimave2 = Qsimmu2.mean(-1, keepdim=True)
        ETave = ETmu.mean(-1, keepdim=True)

        # Combined streamflow [Nstep, Ngrid]: plain mean, or weighted sum over components.
        if muwts is None:
            Qsimave = Qsimmu.mean(-1)
        else:
            Qsimave = (Qsimmu * muwts).sum(-1)
        Qs = Qsimave.unsqueeze(-1)  # [Nstep, Ngrid, 1]

        if outstate:
            return Qs, SNOWPACK, MELTWATER, SM, SUZ, SLZ
        return torch.cat((Qs, Qsimave0, Qsimave1, Qsimave2, ETave), dim=-1)


In [4]:
# Sample data generation
SEED = 42   # fix the RNG so the forcings (and every downstream output) are reproducible
torch.manual_seed(SEED)

Ntime = 365  # Number of time steps (days)
Nbasin = 1   # Number of basins or grid points
mu = 1       # Number of model components

# Daily precipitation (P), uniform in [0, 10) mm/day
P = torch.rand(Ntime, Nbasin) * 10  # Shape: [Ntime, Nbasin]

# Daily temperature (T), uniform in [-5, 25) degrees Celsius
T = torch.rand(Ntime, Nbasin) * 30 - 5  # Shape: [Ntime, Nbasin]

# Daily potential evapotranspiration (PET), uniform in [0, 5) mm/day
PET = torch.rand(Ntime, Nbasin) * 5  # Shape: [Ntime, Nbasin]

# Combine P, T, and PET into the model input; HBVMul reads channels 0/1/2 as P/T/PET
x = torch.stack([P, T, PET], dim=2)  # Shape: [Ntime, Nbasin, 3]

# Determine the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = x.to(device)


In [5]:
x.size()

torch.Size([365, 1, 3])

In [6]:
# HBVMul reads 12 normalised parameters per component (parameters[:, 0..11, :])
Nparas = 12

# Start every parameter at 0.5, i.e. the midpoint of its [lower, upper] scaling range,
# allocated directly on the target device
parameters = torch.full((Nbasin, Nparas, mu), 0.5, dtype=torch.float32, device=device)


In [7]:
# Build the HBV model (it holds no trainable weights; all parameters go to forward)
model = HBVMul().to(device)

# Forward pass with no warm-up, no routing and no precipitation correction;
# the result is Qall with shape [Ntime, Nbasin, 5]
outputs = model(
    x,
    parameters,
    mu,
    muwts=None,
    rtwts=None,
    bufftime=0,
    outstate=False,
    routOpt=False,
    comprout=False,
    corrwts=None,
    pcorr=None,
)


In [8]:
# Channel 0 of Qall is the component-combined streamflow
Qs = outputs[..., 0]  # Shape: [Ntime, Nbasin]

# Drop the autograd graph and copy to host memory so numpy can read it
Qs_cpu = Qs.detach().cpu().numpy()

# Report the shape and the first ten simulated days
print("Simulated streamflow (Qs) shape:", Qs_cpu.shape)
print("First 10 days of simulated streamflow (mm/day):")
print(Qs_cpu[:10, :])


Simulated streamflow (Qs) shape: (365, 1)
First 10 days of simulated streamflow (mm/day):
[[0.000201  ]
 [0.00018096]
 [0.0001637 ]
 [0.00015283]
 [0.00015448]
 [0.00016913]
 [0.00019023]
 [0.00021721]
 [0.00034316]
 [0.00043229]]
