In [None]:
import numpy as np
import torch
from lib.models.models import UniformRate
import ml_collections
from config.config_train_sample import get_config
from lib.utils import utils
import torch.nn.functional as F


In [None]:
"""
Now, when we minimize L_CT, we are sampling (x, x̃) from the forward process and then maximizing the assigned model probability for
the pairing in the reverse direction, just as in L_DT. The slight extra complexity comes from the fact that we are considering the case
when x_k = x_{k+1} and the case when x_k ≠ x_{k+1} separately. When x_k = x_{k+1}, this corresponds to the first term in L_CT, which we can see
is minimizing the reverse rate out of x; this is exactly maximizing the model probability that no transition occurs. When x_k ≠ x_{k+1},
this corresponds to the second term in L_CT, which maximizes the reverse rate from x̃ to x; this in turn maximizes the model probability
for the x̃ → x transition to occur.
"""

In [None]:
# Open questions (for Yannick):
# - Why sample x_t from q_{t|0}, then apply one more transition to get x_tilde,
#   and then compute logits = net(x_tilde, t)?
#     q_{t|0}(x̃ | x_0)  =>  qt0_numer_reg = qt0.view(B, S, S)
# - Two forward passes are needed: p^θ_{0|t}(x_0 | x) to compute R̂^θ_t(x, x′)
#   and p^θ_{0|t}(x_0 | x̃) to compute R̂^θ_t(x̃, x). This is wasteful.

# Toy sizes: states take values in [0, S); samples are (B, D) grids of states.
rate_const = 1
S = 3
B = 2
D = 4
device = 'cpu'

# Uniform generator matrix: every off-diagonal rate is rate_const and each
# diagonal entry is minus the sum of its row's off-diagonal rates, so every
# row sums to zero.
rate = rate_const * np.ones((S, S))
np.fill_diagonal(rate, 0.0)
np.fill_diagonal(rate, -rate.sum(axis=1))

# One float32 copy of the generator per batch element -> (B, S, S).
rate_matrix = torch.from_numpy(rate).float().unsqueeze(0).repeat(B, 1, 1)

# Small sanity example of torch.diag_embed.
v = torch.tensor([1, 2, 3])
matrix = torch.diag_embed(v)

In [None]:
#-----------Transition matrix q_t|0: x0 -> xt ---------------------
# Build the uniform forward process and sample x_t ~ q_{t|0}(. | x_0)
# independently for every (batch, position) pair.
cfg = get_config()
cfg.data.S = S
cfg.model.rate_const = rate_const
uni = UniformRate(cfg, 'cpu')
ts = torch.rand((B,))
qt0 = uni.transition(ts)  # indexed below as qt0[batch, x0, :]
x0 = torch.randint(low=0, high=S, size=(B, D), dtype=torch.int)

# Variant 1: flatten (batch, position) into B*D rows -> (B*D, S).
qt0_rows_reg = qt0[
    torch.arange(B, device=device).repeat_interleave(D),  # each batch index repeated D times
    x0.flatten().long(),  # (B, D) -> (B*D,)
    :,
]
print(qt0_rows_reg, qt0_rows_reg.shape)

# Variant 2: broadcast a (B, 1) batch index against x0 -> (B, D, S).
b = utils.expand_dims(torch.arange(B), (tuple(range(1, x0.dim()))))
qt0_rows_reg2 = qt0[b, x0.long()]
# Both indexing schemes must select the same rows in the same order.
assert torch.equal(qt0_rows_reg2.reshape(-1, S), qt0_rows_reg)

# Log-probabilities of those rows, with a large negative value in place of
# -inf for zero-probability entries. Mask the same (B, D, S) tensor whose log
# is taken: masking with the (B, S, S) qt0 cannot broadcast when D != S.
logits = torch.where(qt0_rows_reg2 <= 0.0, -1e9, torch.log(qt0_rows_reg2))

x_t_cat = torch.distributions.categorical.Categorical(qt0_rows_reg)
x_t = x_t_cat.sample().view(B, D)
print(x_t, x_t.shape)

In [None]:
#-------------- Transition rate: x_t -> x_tilde ------------------
# Simulate a single forward jump out of x_t in two stages:
#   1. pick which position jumps, proportional to its total outgoing rate;
#   2. pick the new state at that position, proportional to the individual rates.
rate = uni.rate(ts)  # indexed below as rate[batch, from_state, to_state]
flat_batch = torch.arange(B, device=device).repeat_interleave(D)  # (B*D,)
flat_xt = x_t.long().flatten()  # (B*D,)

# Rate row for the current state of each (batch, position). Advanced indexing
# returns a copy, so zeroing the diagonal (self-transition) entry leaves
# `rate` untouched and keeps only genuine jumps.
rate_vals_square = rate[flat_batch, flat_xt, :]
rate_vals_square[torch.arange(B * D, device=device), flat_xt] = 0.0
print(rate_vals_square, rate_vals_square.shape)

rate_vals_square = rate_vals_square.view(B, D, S)
print(rate_vals_square, rate_vals_square.shape)

# Stage 1: total outgoing rate per position -> one jumping position per batch row.
rate_vals_square_dimsum = rate_vals_square.sum(dim=2).view(B, D)
print(rate_vals_square_dimsum, rate_vals_square_dimsum.shape)
square_dimcat = torch.distributions.categorical.Categorical(rate_vals_square_dimsum)
square_dims = square_dimcat.sample()  # (B,)
print("Where transition", square_dims, square_dims.shape)

# Stage 2: rates out of the chosen position -> its new value in [0, S).
batch_idx = torch.arange(B, device=device)
rate_new_val_probs = rate_vals_square[batch_idx, square_dims, :]  # (B, S)
print(rate_new_val_probs, rate_new_val_probs.shape)
square_newvalcat = torch.distributions.categorical.Categorical(rate_new_val_probs)
square_newval_samples = square_newvalcat.sample()  # (B,)
print("Transition value", square_newval_samples, square_newval_samples.shape)

# Noisy sample: a copy of x_t with exactly one position changed per batch row.
x_tilde = x_t.clone()
x_tilde[batch_idx, square_dims] = square_newval_samples
print(x_t)
print(x_tilde)


In [None]:
#-----------ELBO-------------------
# Mask that is 1 everywhere except at each position's current noisy state:
# mask_reg[b, d, s] == 0 exactly when s == x_tilde[b, d].
mask_reg = (1.0 - F.one_hot(x_tilde.long(), num_classes=S).float()).to(device)
print(x_tilde)
print(mask_reg, mask_reg.shape)

In [None]:
# Flat (batch, position) index pairs, shared by the gathers below.
flat_batch = torch.arange(B, device=device).repeat_interleave(D)  # (B*D,)
flat_xtilde = x_tilde.long().flatten()  # (B*D,)

# The full q_{t|0} of every batch element, as (B, S, S).
qt0_numer_reg = qt0.view(B, S, S)
print(qt0_numer_reg, qt0_numer_reg.shape)

# Column x̃ of q_{t|0}: qt0_denom_reg[b, d, x0] = qt0[b, x0, x_tilde[b, d]] + 1e-6,
# i.e. q_{t|0}(x̃ | x_0) for every x_0 (computed but not used further here).
qt0_denom_reg = qt0[flat_batch, :, flat_xtilde].view(B, D, S) + 1e-6

# Column x̃ of the rate matrix: rate_vals_reg[b, d, x'] = rate[b, x', x_tilde[b, d]].
rate_vals_reg = rate[flat_batch, :, flat_xtilde].view(B, D, S)
print(rate_vals_reg, rate_vals_reg.shape)

# Drop the x' == x̃ entry, then contract with q_{t|0} transposed over x'.
masked_rates = mask_reg * rate_vals_reg
print(masked_rates)
reg_tmp = masked_rates @ qt0_numer_reg.transpose(1, 2)
print(reg_tmp, reg_tmp.shape)

In [None]:
# Fresh setup for the loss experiments below: sizes, forward process, a random
# time per batch element and a random state grid xt of shape (B, D).
rate_const = 1
S = 3
B = 2
D = 4
cfg = get_config()
cfg.data.S, cfg.model.rate_const = S, rate_const
uni = UniformRate(cfg, 'cpu')
ts = torch.rand((B,))
xt = torch.randint(low=0, high=S, size=(B, D), dtype=torch.int)
print(xt)

qt0 = uni.transition(ts)
# Swap the two state axes: qt0_y2x[b, y, x] == qt0[b, x, y].
qt0_y2x = qt0.transpose(1, 2)
print(qt0, qt0.shape)
print(qt0_y2x, qt0_y2x.shape)
# Element-wise comparison: all True iff every qt0[b] is symmetric.
print(qt0 == qt0_y2x)

In [None]:
# Batch index shaped to broadcast against xt (presumably (B, 1) for 2-D xt).
b = utils.expand_dims(
    torch.arange(xt.shape[0]), tuple(range(1, xt.dim()))
)
print(b, b.shape)
# qt0_y2x[b, d, y] = qt0[b, y, xt[b, d]]: the column of q_{t|0} at each current state.
# Derived from qt0 rather than from the previous qt0_y2x so that re-running this
# cell does not index an already-gathered (B, D, S) tensor a second time.
qt0_y2x = qt0.permute(0, 2, 1)[b, xt.long()]
print(qt0_y2x, qt0_y2x.shape)

In [None]:
# No network here: the gathered transition columns stand in for model logits.
logits = qt0_y2x
log_p0t = F.log_softmax(logits, dim=-1)
print(log_p0t, log_p0t.shape)

# log q_{t|0}, with (near-)zero entries replaced by a large negative number.
log_qt0 = torch.log(qt0).masked_fill(qt0 <= 1e-35, -1e9)
print(log_qt0, log_qt0.shape)
# Insert singleton axes after the batch axis so it broadcasts over positions.
log_qt0 = utils.expand_dims(log_qt0, axis=list(range(1, xt.dim())))
print(log_qt0, log_qt0.shape)
log_p0t = log_p0t[..., None]
print(log_p0t, log_p0t.shape)

# Marginalize over x_0 (axis -2):
# log_prob[..., x] = log sum_{x0} p0t[..., x0] * qt0[x0, x].
log_prob = (log_p0t + log_qt0).logsumexp(dim=-2)
print(log_prob, log_prob.shape)


In [None]:
# The same marginal, computed in probability space with a matmul, then read
# off at the observed states xt through a one-hot mask.
qt0 = uni.transition(ts)
xt_onehot = F.one_hot(xt.long(), S)
p0t = logits.softmax(dim=-1)
print(p0t, p0t.shape)
qt0 = utils.expand_dims(qt0, axis=list(range(1, xt.dim() - 1)))
print(qt0, qt0.shape)
prob_all = torch.matmul(p0t, qt0)
print(prob_all.shape)
log_prob = torch.log(prob_all + 1e-35)
print(log_prob, log_prob.shape)
log_xt = (log_prob * xt_onehot).sum(dim=-1)
print(log_xt, log_xt.shape)

In [None]:
# Log q_{t|0}(. | x_0) for every (batch, position), then sample x_t from it.
qt0 = uni.transition(ts)
b = utils.expand_dims(torch.arange(B), (tuple(range(1, x0.dim()))))
qt0_rows_reg2 = qt0[b, x0.long()]
print(qt0_rows_reg2, qt0_rows_reg2.shape)
logits = torch.where(qt0_rows_reg2 <= 0.0, -1e9, torch.log(qt0_rows_reg2))
print(logits, logits.shape)

# Categorical's first positional parameter is `probs`. These are
# log-probabilities (<= 0), which are invalid probs, so pass them as `logits=`.
x_t_cat = torch.distributions.categorical.Categorical(logits=logits).sample()
print(x_t_cat, x_t_cat.shape)

In [None]:
# Loss from per-state log-probabilities ll_all (B, D, S) at observed states xt (B, D):
#   loss = -[(S - 1) * ll_xt + sum_y log1mexp(ll_all[..., y]) - log1mexp(ll_xt)]
ll_all = logits  # B, D, S
# ll_xt must be the log-probability of xt, gathered from ll_all. xt itself holds
# integer state indices (>= 0); utils.log1mexp (presumably log(1 - exp(x)))
# is -inf/NaN for such inputs.
ll_xt = torch.gather(ll_all, -1, xt.long().unsqueeze(-1)).squeeze(-1)  # B, D
loss = -(
    (S - 1) * ll_xt
    + torch.sum(utils.log1mexp(ll_all), dim=-1)
    - utils.log1mexp(ll_xt)
)
print(loss, loss.shape)
# Per-example weights (all ones here), broadcast against the trailing dims of loss.
weight = torch.ones((B,), dtype=torch.float32)
weight = utils.expand_dims(weight, axis=list(range(1, loss.dim())))
print(weight, weight.shape)
loss = loss * weight
print(loss, loss.shape)
# Sum over all elements, then average over the batch.
loss = torch.sum(loss) / xt.shape[0]
print(loss, loss.shape)

In [None]:
# Backward ratio term: p(y) / p(x_t) weighted by q_{t|0} in the y -> x_t direction.
ll_all = logits  # B, D, S log-probabilities
# Log-probability of the observed state xt (not the raw state index itself).
ll_xt = torch.gather(ll_all, -1, xt.long().unsqueeze(-1)).squeeze(-1)  # B, D
xt_onehot = F.one_hot(xt.long(), num_classes=S)
b = utils.expand_dims(torch.arange(xt.shape[0]), tuple(range(1, xt.dim())))
print(b, b.shape)
qt0_x2y = uni.transition(ts)
print(qt0_x2y, qt0_x2y.shape)
qt0_y2x = qt0_x2y.permute(0, 2, 1)
print(qt0_y2x, qt0_y2x.shape)
# qt0_y2x[b, d, y] = qt0_x2y[b, y, xt[b, d]]; recomputed from qt0_x2y each run.
qt0_y2x = qt0_y2x[b, xt.long()]
print(qt0_y2x, qt0_y2x.shape)
ll_xt = ll_xt.unsqueeze(-1)  # B, D, 1 so it broadcasts against ll_all
print("ll", ll_xt, ll_xt.shape)
backwd = torch.exp(ll_all - ll_xt) * qt0_y2x
print(backwd, backwd.shape)


In [None]:
# Sum the backward ratio over all states y != x_t.
first_term = torch.sum(backwd * (1 - xt_onehot), dim=-1)
print(first_term, first_term.shape)
# Row x_t of q_{t|0} per (batch, position). Stored under a new name so that
# re-running this cell does not gather from an already-gathered tensor.
qt0_x2y_xt = qt0_x2y[b, xt.long()]
print(qt0_x2y_xt, qt0_x2y_xt.shape)
fwd = (ll_xt - ll_all) * qt0_x2y_xt
print(fwd, fwd.shape)
second_term = torch.sum(fwd * (1 - xt_onehot), dim=-1)
print(second_term, second_term.shape)
loss = first_term - second_term
print(loss, loss.shape)

In [None]:
# Uniform per-example weights, reshaped to broadcast against `loss`; the
# weighted loss is then summed and averaged over the batch dimension.
weight = utils.expand_dims(
    torch.ones((B,), dtype=torch.float32),
    axis=list(range(1, loss.dim())),
)
print(weight, weight.shape)
loss = loss * weight
print(loss, loss.shape)
loss = loss.sum() / xt.shape[0]
print(loss, loss.shape)

In [61]:
# Sampling time grid: 1000 evenly spaced times from 1.0 down to 1e-3,
# followed by a final 0.
ts = np.append(np.linspace(1.0, 1e-3, 1000), 0)

# Print every time except the trailing 0.
for t in ts[:-1]:
    print(t)

1.0
0.999
0.998
0.997
0.996
0.995
0.994
0.993
0.992
0.991
0.99
0.989
0.988
0.987
0.986
0.985
0.984
0.983
0.982
0.981
0.98
0.979
0.978
0.977
0.976
0.975
0.974
0.973
0.972
0.971
0.97
0.969
0.968
0.967
0.966
0.965
0.964
0.963
0.962
0.961
0.96
0.959
0.958
0.957
0.956
0.955
0.954
0.953
0.952
0.951
0.95
0.949
0.948
0.947
0.946
0.945
0.944
0.943
0.942
0.941
0.94
0.9390000000000001
0.938
0.937
0.9359999999999999
0.935
0.9339999999999999
0.933
0.9319999999999999
0.931
0.9299999999999999
0.929
0.9279999999999999
0.927
0.926
0.925
0.924
0.923
0.922
0.921
0.92
0.919
0.918
0.917
0.916
0.915
0.914
0.913
0.912
0.911
0.91
0.909
0.908
0.907
0.906
0.905
0.904
0.903
0.902
0.901
0.9
0.899
0.898
0.897
0.896
0.895
0.894
0.893
0.892
0.891
0.89
0.889
0.888
0.887
0.886
0.885
0.884
0.883
0.882
0.881
0.88
0.879
0.878
0.877
0.876
0.875
0.874
0.873
0.872
0.871
0.87
0.869
0.868
0.867
0.866
0.865
0.864
0.863
0.862
0.861
0.86
0.859
0.858
0.857
0.856
0.855
0.854
0.853
0.852
0.851
0.85
0.849
0.848
0.847
0.846
0.845
0.8