In [20]:
"""
By Zhewen Hou
"""
import numpy as np
import torch
from torch import Tensor
from torch.distributions.normal import Normal
from scipy.stats import norm
from sklearn import *
import time
import sys

In [21]:
# numpy implementation
# def dgmm1d(x, mu, sigma, pi):
#    pdf_gmm = np.sum([pi[k] * norm.pdf(x, loc=mu[k], scale=sigma[k]) for k in range(len(mu))], axis=0)
#    return pdf_gmm

# density of 1D Gaussian Mixture
def dgmm1d(x: Tensor, mu: Tensor, sigma: Tensor, pi: Tensor) -> Tensor:
    """
    Density of a 1D Gaussian mixture evaluated at ``x``.

    One term is built for every index along the first dimension of ``mu``:
    ``pi[k]`` times the Normal(mu[k], sigma[k]) density, where the density is
    obtained by exponentiating ``log_prob``. The terms are stacked and summed.

    :param x: points at which the density is evaluated
    :param mu: component means, indexed by component along dim 0
    :param sigma: component standard deviations, indexed like ``mu``
    :param pi: mixing proportions, indexed like ``mu``
    :return: sum over components of the weighted component densities
    """
    weighted_terms = []
    for k in range(len(mu)):
        component = Normal(mu[k], sigma[k])
        weighted_terms.append(pi[k] * torch.exp(component.log_prob(x)))
    return torch.stack(weighted_terms).sum(dim=0)

In [22]:
# numpy implementation
# def pgmm1d(x, mu, sigma, pi):
#    cdf_gmm = np.sum([pi[k] * norm.cdf(x, loc=mu[k], scale=sigma[k]) for k in range(len(mu))], axis=0)
#    return cdf_gmm

# cdf of 1D Gaussian Mixture
def pgmm1d(x: Tensor, mu: Tensor, sigma: Tensor, pi: Tensor) -> Tensor:
    """
    Cumulative distribution function of a 1D Gaussian mixture evaluated at ``x``.

    One term is built for every index along the first dimension of ``mu``:
    ``pi[k]`` times the Normal(mu[k], sigma[k]) CDF at ``x``. The terms are
    stacked and summed.

    :param x: points at which the CDF is evaluated
    :param mu: component means, indexed by component along dim 0
    :param sigma: component standard deviations, indexed like ``mu``
    :param pi: mixing proportions, indexed like ``mu``
    :return: sum over components of the weighted component CDFs
    """
    weighted_terms = []
    for k in range(len(mu)):
        component = Normal(mu[k], sigma[k])
        weighted_terms.append(pi[k] * component.cdf(x))
    return torch.stack(weighted_terms).sum(dim=0)

In [31]:
# numpy implementation: quantile of 1D Gaussian Mixture
#def qgmm1d(q, mu, sigma, pi):
#    ppf_full = np.array([norm.ppf(q, loc=mu[k], scale=sigma[k]) for k in range(len(mu))]).flatten()
#    ppf_full.sort()
#    cdf_gmm = np.sum([pi[k] * norm.cdf(ppf_full, loc=mu[k], scale=sigma[k]) for k in range(len(mu))], axis=0)
#    ## 1D linear interpolation
#    ppf_gmm = np.interp(q, cdf_gmm, ppf_full)
#    return ppf_gmm

# implementation of np.interp with pytorch
def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
    """
    Row-wise (batched) piecewise-linear interpolation in the spirit of ``np.interp``.

    Row ``i`` of ``x`` is interpolated against the knots ``(xp[i], fp[i])``.

    Compared with a plain slope/intercept lookup, two cases are handled explicitly:

    * Queries outside ``[xp[i, 0], xp[i, -1]]`` are clamped to that range, so the
      result is the boundary value instead of a linear extrapolation (this is
      what ``np.interp`` does by default).
    * Zero-width knot intervals (``xp[i, j + 1] == xp[i, j]``) get slope 0 instead
      of ``inf``/``nan``. Without this guard the division poisons the gradients
      of ``xp``/``fp`` with NaN even when the degenerate segment is never
      selected.

    :param x: (batch, m) query points
    :param xp: (batch, n) knot positions, expected non-decreasing along dim 1, n >= 2
    :param fp: (batch, n) knot values
    :return: (batch, m) interpolated values
    """
    dx = xp[:, 1:] - xp[:, :-1]
    dy = fp[:, 1:] - fp[:, :-1]

    # Divide by 1 on degenerate segments, then force their slope to 0; both the
    # forward value and the backward pass stay finite this way.
    degenerate = dx == 0
    slope = (dy / dx.masked_fill(degenerate, 1)).masked_fill(degenerate, 0)
    intercept = fp[:, :-1] - slope * xp[:, :-1]

    # Clamp the queries into the knot range so the boundary values are returned
    # outside of it.
    x_in = torch.maximum(torch.minimum(x, xp[:, -1:]), xp[:, :1])

    # Segment index = (number of knots <= query) - 1, limited to valid segments.
    indices = torch.sum(torch.ge(x_in[:, :, None], xp[:, None, :]), dim=2) - 1
    indices = torch.clamp(indices, 0, slope.size(1) - 1)

    return torch.gather(slope, 1, indices) * x_in + torch.gather(intercept, 1, indices)


# quantile of 1D Gaussian Mixture
def qgmm1d(q: Tensor, mu: Tensor, sigma: Tensor, pi: Tensor) -> Tensor:
    """
    Approximate quantile function of a batch of 1D Gaussian mixtures.

    For every distribution, the quantiles of each individual component at each
    level in ``q`` are pooled and sorted to form a grid. The mixture CDF is
    evaluated on that grid and then inverted at the levels ``q`` by linear
    interpolation (``interp``).

    :param q: quantile levels, viewed as ``(q.size(0), 1, 1)`` (a 1D tensor of m levels)
    :param mu: n x k, component means
    :param sigma: n x k, component standard deviations
    :param pi: n x k, mixing proportions
    :return: interpolated mixture quantiles, n x m before the final ``squeeze()``;
             note that ``squeeze()`` drops any dimension of size 1
    """
    n_levels = q.size(0)
    n_dist, n_comp = mu.size(0), mu.size(1)

    # Quantile of every component at every level: (m, n, k).
    levels = q.view(n_levels, 1, 1).expand(n_levels, n_dist, n_comp)
    component_ppf = Normal(mu, sigma).icdf(levels)

    # Pool the component quantiles per distribution -> (n, m * k), ascending.
    grid = component_ppf.permute(1, 0, 2).flatten(1)
    grid = grid.sort(dim=1).values

    # Mixture CDF on the grid: weight each component CDF and sum over components.
    components = Normal(mu[:, None], sigma[:, None])
    grid_cdf = (pi[:, None] * components.cdf(grid[:, :, None])).sum(dim=2)

    # Invert the CDF: interpolate grid positions as a function of CDF value.
    quantiles = interp(q.expand(n_dist, -1), grid_cdf, grid)
    return quantiles.squeeze()

In [32]:
def batch_dWasserstein(Y: Tensor, mu: Tensor, sigma: Tensor, pi: Tensor, q: Tensor, p: int = 2) -> Tensor:
    """
    p-Wasserstein distance between empirical quantile functions and GMM quantile functions.

    The distance is approximated on the grid of quantile levels ``q`` as
    ``mean(|Q_gmm(q) - Y| ** p) ** (1 / p)``, one value per distribution.

    :param Y: n x m, empirical quantiles at the levels ``q``
    :param mu: n x k, component means
    :param sigma: n x k, component standard deviations
    :param pi: n x k, mixing proportions
    :param q: m quantile levels
    :param p: order of the Wasserstein distance (default 2)
    :return: tensor with n entries, one distance per distribution
    """
    # qgmm1d squeezes its result; restore the (n, m) layout so that a single
    # distribution or a single quantile level neither breaks mean(dim=1) nor
    # silently broadcasts against Y.
    ppf_gmm = qgmm1d(q, mu, sigma, pi).reshape(mu.size(0), -1)
    # The absolute value keeps odd p well defined (no cancellation of signed
    # errors, no root of a negative mean); for even p the value is unchanged.
    Wp = torch.mean(torch.abs(ppf_gmm - Y) ** p, dim=1) ** (1 / p)
    return Wp


# main functions for fitting Wasserstein mixture regression
def WDL(Y, q_vec, mu, sd, pi):
    """
    Wasserstein loss between empirical quantile functions and GMM quantile functions.

    Notation:
        n: number of observations (distributions)
        m: number of percentiles
        k: number of components in the GMM

    :param Y: n x m, empirical quantile function of each observation
    :param q_vec: m, quantile levels
    :param mu: n x k, mean of the components in the GMM
    :param sd: n x k, standard deviation of the components in the GMM
    :param pi: n x k, mixing proportion of the components in the GMM
    :return: scalar tensor, the mean over observations of the squared distance
             returned by ``batch_dWasserstein`` (called with its default order)
    """
    # One distance per distribution, squared and then averaged across distributions.
    per_distribution = batch_dWasserstein(Y, mu, sd, pi, q_vec)
    return per_distribution.pow(2).mean()

#### Testing Below

In [33]:
'''
# set parameters
SEED = 234234
K = 2 # number of components in the GMM
n_dist = 300 ## number of distributions
omega = 0.1 # default 0.1
n_sample = 300

# simulate the data
np.random.seed(SEED)
torch.random.manual_seed(SEED)
X = np.random.uniform(size=(n_dist, 3)) * 2 - 1 # generate X
Y = np.zeros((n_dist, n_sample)) # generate Y


## simulate Y
for i in range(n_dist):
    mu_1 = X[i, 0]
    mu_2 = 2 * X[i, 1]**2 + 2
    mu_true = [mu_1, mu_2]
    #print('mu:', mu_true)
    sig_1 = np.abs(X[i, 1]) + 0.5
    sig_2 = np.abs(X[i, 0]) + 0.5
    sig_true = [sig_1, sig_2]
    #print('sig:', sig_true)
    pi_1 = 1 / (1 + np.exp(X[i, 2]))
    pi_true = [pi_1, 1-pi_1]
    #print('pi:', pi_true)
    ## simulate noise
    eps_noise = np.random.normal(loc=0, scale=omega, size=1)
    ## simulate responses
    var_gaussian = np.array([np.random.normal(loc=mu_true[k]+eps_noise,
                                              scale=sig_true[k],
                                              size=n_sample) for k in range(K)]).T
    var_mult = np.random.choice(range(K), size=n_sample, replace=True, p=pi_true)
    var_mult = np.eye(K)[var_mult]
    var_GMM = np.sum(var_mult * var_gaussian, axis=1)
    Y[i] = np.sort(var_GMM)
    
print(Y.shape)
print(Y)

# prepare for model fitting
K_mix = 2
n_iter = 1000
lr = 1e-1
v_lr = np.array([1] + [lr] * n_iter)
n_dist = Y.shape[0]
n_levs = 100
q_vec = torch.arange(1, n_levs, dtype=torch.float64) / n_levs # quantile levels
print('q_vec', q_vec)
## transform Y
Q_mat = np.array([np.quantile(Y[i], q_vec) for i in range(n_dist)]) # calculate the empirical quantile function
'''

(300, 300)
[[-3.09883251 -2.86643922 -2.51388866 ...  4.31112254  4.3739572
   5.20220006]
 [-2.91661178 -2.65928491 -2.44640414 ...  4.53591171  4.56868362
   4.62578348]
 [-2.4340326  -2.40057418 -2.28007935 ...  5.02492275  5.16380261
   5.78538448]
 ...
 [-2.46941656 -2.20831643 -2.10139391 ...  5.88877085  6.20330087
   6.4078245 ]
 [-2.84557891 -2.6724476  -2.48777096 ...  4.07628841  4.22882558
   4.26913273]
 [-3.17172267 -3.08769588 -3.08434324 ...  4.32627668  4.42752034
   4.526383  ]]
q_vec tensor([0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700, 0.0800, 0.0900,
        0.1000, 0.1100, 0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700, 0.1800,
        0.1900, 0.2000, 0.2100, 0.2200, 0.2300, 0.2400, 0.2500, 0.2600, 0.2700,
        0.2800, 0.2900, 0.3000, 0.3100, 0.3200, 0.3300, 0.3400, 0.3500, 0.3600,
        0.3700, 0.3800, 0.3900, 0.4000, 0.4100, 0.4200, 0.4300, 0.4400, 0.4500,
        0.4600, 0.4700, 0.4800, 0.4900, 0.5000, 0.5100, 0.5200, 0.5300, 0.5400,
        0.55

In [38]:
'''
# step 1. train-val split
n_train = int(0.8 * n_dist)
loc_train = np.random.choice(n_dist, n_train, replace=False)
loc_val = np.setdiff1d(np.arange(n_dist), loc_train)


X_train = torch.tensor(X[loc_train])
# generate true parameters
mu_true = torch.cat((X_train[:, 0].unsqueeze(dim=1), (2 * X_train[:, 1] ** 2 + 2).unsqueeze(dim=1)), dim=1)
sd_true = torch.cat(
    (torch.abs(X_train[:, 1]).unsqueeze(dim=1) + 0.5, torch.abs(X_train[:, 0]).unsqueeze(dim=1) + 0.5), dim=1)
alpha_true = torch.cat((torch.zeros_like(X_train[:, 2].unsqueeze(dim=1)), X_train[:, 2].unsqueeze(dim=1)), dim=1)
pi_true = torch.softmax(alpha_true, dim=1)

# generate features
X_train = torch.cat((X_train, torch.abs(X_train), X_train ** 2), 1)
y_train = torch.tensor(Q_mat[loc_train])
X_val = torch.tensor(X[loc_val])
X_val = torch.cat((X_val, torch.abs(X_val), X_val ** 2), 1)
y_val = torch.tensor(Q_mat[loc_val])


# test the speed of the loss function
start = time.time()

for i in range(100):
    alpha_matrix = torch.tensor(np.random.randn(9, 2), requires_grad=True)
    mu_matrix = torch.tensor(np.random.randn(9, 2), requires_grad=True)
    sd_matrix = torch.tensor(np.random.randn(9, 2), requires_grad=True)

    alpha = torch.matmul(X_train, alpha_matrix)
    pi = torch.softmax(alpha, dim=1)

    mu = torch.matmul(X_train, mu_matrix)
    
    sd = torch.abs(torch.matmul(X_train, sd_matrix))

    loss = WDL(y_train, q_vec, mu, sd, pi)

    print(loss)
print(time.time() - start)
'''

torch.Size([240, 2])
tensor(3.3191, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(6.7576, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(10.8323, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(3.6517, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(2.2748, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(9.4206, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(4.4102, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(5.5286, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(3.3920, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(2.5275, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(6.7134, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240, 2])
tensor(3.1449, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([240

In [39]:
'''
# step 2. train the model
# initialize the optimizer
optimizer = torch.optim.SGD([alpha_matrix, mu_matrix, sd_matrix], lr=0.6)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=170, gamma=0.84)

# train the model
start = time.time()
for k in range(1000):
    """alpha = torch.matmul(X_train[:, 2:3], alpha_matrix)
    mu = torch.matmul(X_train[:, [0, 4, 6]], mu_matrix)
    sd = torch.abs(torch.matmul(torch.abs(X_train[:, [0, 1, 6]]), sd_matrix))"""

    alpha = torch.matmul(X_train, alpha_matrix)
    mu = torch.matmul(X_train, mu_matrix)
    sd = torch.abs(torch.matmul(torch.abs(X_train), sd_matrix))

    pi = torch.softmax(alpha, dim=1)

    optimizer.zero_grad()
    loss = WDL(y_train, q_vec, mu, sd, pi)
    print(k, loss.item())
    loss.backward()
    optimizer.step()
    scheduler.step()
end = time.time()
print(end - start)

print(mu)
'''

0 1.8953202843639845
1 1.3534818272009868
2 1.1289146756614765
3 0.9695439868641276
4 0.8410743849086068
5 0.7460565953840385
6 0.6636767057837374
7 0.5848967261742131
8 0.5266667713177535
9 0.4884040086465751
10 0.4627104085141768
11 0.4445376934490593
12 0.4312602267888029
13 0.42104361181653877
14 0.4129402232543243
15 0.4062451190122733
16 0.4004599705522235
17 0.39538086100784203
18 0.3908074570266714
19 0.3866144693195827
20 0.38270288903928823
21 0.379027041679309
22 0.3755465731476292
23 0.3722398365090432
24 0.3690815793859981
25 0.36605019903641384
26 0.3631314508306239
27 0.360317832741287
28 0.3576081756014649
29 0.35498246209752826
30 0.3524420986209344
31 0.3499822246531771
32 0.34761627252171307
33 0.34533445048140243
34 0.34314551886775635
35 0.34104293787801937
36 0.3390550776236632
37 0.33718703170860564
38 0.33544796533789367
39 0.3338419256028747
40 0.3324518168267894
41 0.3312934539459188
42 0.33045108734716455
43 0.329935288878456
44 0.32999522922057495
45 0.33071

357 0.229782807696876
358 0.22973927574679262
359 0.22969532287394812
360 0.22965191514094574
361 0.22960868005414736
362 0.22956475494083023
363 0.2295210815989688
364 0.22947797990538116
365 0.22943574721111146
366 0.22939352814827402
367 0.2293515024914417
368 0.22930978099982874
369 0.22926818707888863
370 0.22922623295609354
371 0.2291843900285459
372 0.2291427756019279
373 0.22910128174218394
374 0.22906004098120042
375 0.2290192402148177
376 0.22897839624326904
377 0.22893743902574137
378 0.22889648481200867
379 0.22885560519539255
380 0.22881504708271502
381 0.22877472960735712
382 0.22873457876907985
383 0.22869445473911915
384 0.2286544491594006
385 0.22861426596520606
386 0.22857374882523296
387 0.2285334578519921
388 0.22849328004210206
389 0.2284536787905523
390 0.22841372905912394
391 0.22837366785901717
392 0.22833344268351657
393 0.2282935780287301
394 0.22825399282671052
395 0.2282144914917676
396 0.22817536085882806
397 0.22813652017116812
398 0.22809742568554786
399 

704 0.21727433258976012
705 0.217243830528807
706 0.21721328847861296
707 0.217182694039943
708 0.217152193898954
709 0.21712178421433764
710 0.21709130430211407
711 0.21706083972909576
712 0.21703028522505444
713 0.21699983631903605
714 0.2169691645948607
715 0.21693844457732112
716 0.21690760832696887
717 0.21687669246277633
718 0.21684574423048394
719 0.21681483269126864
720 0.21678392413437755
721 0.21675310532729608
722 0.21672232702925734
723 0.21669131159030464
724 0.21666057232522012
725 0.21662957638775163
726 0.2165986180098112
727 0.21656802905653125
728 0.21653730720190842
729 0.21650644488418452
730 0.2164752089620812
731 0.2164438389158318
732 0.21641251897786293
733 0.2163814576374238
734 0.2163503619391959
735 0.21631913117259532
736 0.21628793248769487
737 0.21625669535489375
738 0.21622538363958915
739 0.21619384026412206
740 0.21616083557601007
741 0.2161272286383427
742 0.21609435347357778
743 0.21606118953875514
744 0.21602802774804297
745 0.21599539361971978
746 0

In [11]:
'''
# step 3. evaluate the model
alpha = torch.matmul(X_train, alpha_matrix)
mu = torch.matmul(X_train, mu_matrix)
sd = torch.abs(torch.matmul(torch.abs(X_train), sd_matrix))
pi = torch.softmax(alpha, dim=1)

# calculate the loss
loss = WDL(y_train, q_vec, mu, sd, pi)
# calculate the difference between estimators and real values
print((mu - mu_true).detach().numpy())
print((sd - sd_true).detach().numpy())
print((pi - pi_true).detach().numpy())
print(loss.item())
print()
'''

[[ 6.66676478e-01 -2.84067649e-01]
 [ 2.96840067e-01 -1.44553624e+00]
 [ 1.38995844e+00  5.06455550e-01]
 [ 1.12654493e+00  4.77331162e-01]
 [ 9.64291138e-01  1.58315217e-01]
 [ 1.81584360e-01  9.39345539e-01]
 [ 1.25845916e+00  2.35090734e-01]
 [ 5.91216707e-01  7.41017517e-01]
 [ 3.45591939e-01  1.25174015e+00]
 [ 2.51078522e-01 -5.05201003e-01]
 [-1.21620893e-01  8.20351780e-01]
 [ 7.35562548e-01  8.44545562e-01]
 [ 4.86298687e-01  1.71218464e+00]
 [ 4.67980513e-01  1.39678620e+00]
 [ 5.06183345e-01  8.41167563e-01]
 [ 2.36277541e-01  7.46228737e-01]
 [ 4.38339687e-01  1.17203483e-01]
 [ 6.32575823e-01  1.34099146e+00]
 [ 2.30677905e-01  1.51226974e-01]
 [ 9.95972070e-01  1.62690942e-01]
 [ 7.01800142e-01 -8.03170859e-02]
 [ 6.60889340e-01 -5.47164647e-01]
 [ 1.14072911e+00  8.68564505e-01]
 [ 1.01782373e+00  7.42340494e-01]
 [ 4.54636171e-01  8.62823627e-01]
 [ 4.08696982e-01  6.08927207e-01]
 [ 1.08478362e+00  5.49710235e-01]
 [ 1.07833246e+00  2.42432466e-01]
 [ 6.86554632e-01  6

In [10]:
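# NOTE: the original purpose of this disabled cell was not recorded. It resets the learning rate to 0.5 and runs 500 further training steps.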
"""optimizer.param_groups[0]['lr'] = 0.5
for k in range(500):
    alpha = torch.matmul(X_train[:, 2:3], alpha_matrix)
    mu = torch.matmul(X_train[:, [0, 4, 6]], mu_matrix)
    sd = torch.abs(torch.matmul(torch.abs(X_train[:, [0, 1, 6]]), sd_matrix))

    alpha = torch.matmul(X_train, alpha_matrix)
    mu = torch.matmul(X_train, mu_matrix)
    sd = torch.abs(torch.matmul(torch.abs(X_train), sd_matrix))

    pi = torch.softmax(alpha, dim=1)

    optimizer.zero_grad()
    loss = WDL(y_train, q_vec, mu, sd, pi)
    print(k, loss.item())
    loss.backward()
    optimizer.step()
    scheduler.step()"""


"""alpha = torch.matmul(X_train[:, 2:3], alpha_matrix)
mu = torch.matmul(X_train[:, [0, 4, 6]], mu_matrix)
sd = torch.abs(torch.matmul(torch.abs(X_train[:, [0, 1, 6]]), sd_matrix))"""
# WDL_example.txt
# Displaying WDL_example.txt.

SyntaxError: invalid syntax (1710302979.py, line 26)