In [None]:
import numpy as np
import torch
from lib.models.models import UniformRate
import ml_collections
from config.config_train_sample import get_config
from lib.utils import utils
import torch.nn.functional as F
from lib.networks.networks_paul import UNet
"""
Now, when we minimize L_CT, we are sampling (x, x̃) from the forward process and then maximizing the assigned model probability for

the pairing in the reverse direction, just as in L_DT. The slight extra complexity comes from the fact that we are considering the case

when x_k = x_{k+1} and the case when x_k ≠ x_{k+1} separately. When x_k = x_{k+1}, this corresponds to the first term in L_CT, which we can see

is minimizing the reverse rate out of x, which is exactly maximizing the model probability for no transition to occur. When x_k ≠ x_{k+1},

this corresponds to the second term in L_CT, which is maximizing the reverse rate from x̃ to x, which in turn maximizes the model probability

for the x̃ to x transition to occur.
"""


In [None]:
import time

# ---- Experiment constants: each one is defined exactly once ----
S = 256                # number of discrete states per dimension
B = 64                 # batch size
C, H, W = (1, 32, 32)  # image layout the UNet is called with
D = C * H * W          # flattened number of dimensions (1024)
rate_const = 1
device = 'cpu'

# Fix the global RNG so the random input (and the UNet initialisation) is
# reproducible after "Restart Kernel -> Run All".
torch.manual_seed(0)

cfg = get_config()
cfg.data.S = S
cfg.model.rate_const = rate_const

# Forward process and denoising network; both use the single `device` setting.
model = UniformRate(cfg, device)
unet = UNet(
    in_channel=1,
    out_channel=1,
    channel=32,
    channel_multiplier=cfg.model.ch_mult,
    n_res_blocks=cfg.model.num_res_blocks,
    attn_resolutions=[16],
    num_heads=1,
    dropout=cfg.model.dropout,
    model_output='logits',  # 'logits' or 'logistic_pars'
    num_classes=S,
    x_min_max=(0, 255),
    img_size=32,
)

t = 1
# Random noisy input x_t with integer states in [0, S), shape (B, D).
xt = torch.randint(low=0, high=S, size=(B, D), dtype=torch.int)
x = xt.view(B, C, H, W)

# Network output is reshaped to one score per (sample, dimension, state).
x_pred = unet(x, t * torch.ones((B,), device=device))
x_pred = x_pred.view(B, D, S)
print(x_pred, x_pred.shape)

# Normalise over the state axis to get log-probabilities.
log_p0t = F.log_softmax(x_pred, dim=2)
print(log_p0t, log_p0t.shape)



In [30]:
# Transition matrix of the forward process, evaluated at time t for every sample.
t_batch = t * torch.ones((B,))
qt_test = model.transition(t_batch)
print(qt_test.shape)

# Axes 1 .. xt.dim()-2 are requested; for a 2-D xt this axis list is empty.
extra_axes = list(range(1, xt.dim() - 1))
qt_test = utils.expand_dims(qt_test, axis=extra_axes)
print(qt_test.shape)

# Elementwise log, with -1e9 substituted wherever the probability is <= 1e-35.
torch.where(qt_test <= 1e-35, -1e9, torch.log(qt_test))

torch.Size([64, 256, 256])
torch.Size([64, 256, 256])


tensor([[[-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         [-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         [-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         ...,
         [-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         [-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         [-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452]],

        [[-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         [-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         [-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         ...,
         [-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         [-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         [-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452]],

        [[-5.5452, -5.5452, -5.5452,  ..., -5.5452, -5.5452, -5.5452],
         [-5.5452, -5.5452, -5.5452,  ..., -5

: 

In [27]:
h = 0.01
start_opt = time.time()
t_eps = t - h  # tau

# Per-sample time vectors and the axis list, built once and reused below.
tau_batch = t_eps * torch.ones((B,), device=device)
t_batch = t * torch.ones((B,), device=device)
state_axes = list(range(1, xt.ndim))

# Transition matrix at tau, (N, S, S) before the extra axes are inserted.
q_teps_0 = utils.expand_dims(model.transition(tau_batch), axis=state_axes)

# Transition between tau and t, (N, S, S), with the last two axes swapped so
# that axis 1 can be indexed by the observed state x_t.
q_t_teps = model.transit_between(tau_batch, t_batch).permute(0, 2, 1)

# Batch index broadcast against xt, then one row picked per (sample, dimension).
b = utils.expand_dims(torch.arange(xt.shape[0]), axis=state_axes)
q_t_teps = q_t_teps[b, xt.long()].unsqueeze(-2)
print("q_teps_0", q_teps_0.shape)
print("q_t_teps", q_t_teps.shape)

# Broadcast product of the two factors.
qt0 = q_teps_0 * q_t_teps
print(qt0.shape)

end_opt = time.time()



q_teps_0 torch.Size([64, 1, 256, 256])
q_t_teps torch.Size([64, 1024, 1, 256])
torch.Size([64, 1024, 256, 256])


In [28]:
# Logs of the two factors; zero probabilities map to -1e9 instead of -inf.
# Descriptive names are used so that `b`, which other cells use for the
# batch-index tensor, is not overwritten by a temporary.
log_q_teps_0 = torch.where(q_teps_0 <= 0.0, -1e9, torch.log(q_teps_0))
log_q_t_teps = torch.where(q_t_teps <= 0.0, -1e9, torch.log(q_t_teps))

# Sum of logs; broadcasts the same way as the product q_teps_0 * q_t_teps.
c = log_q_teps_0 + log_q_t_teps

In [None]:
# Rebuild the (B, S, S) matrix before indexing it. Indexing the already-indexed
# `q_t_teps` a second time yields the wrong shape, so this keeps the cell
# re-runnable.
q_t_teps_mat = model.transit_between(
    t_eps * torch.ones((B,), device=device), t * torch.ones((B,), device=device)
).permute(0, 2, 1)
b = utils.expand_dims(torch.arange(xt.shape[0]), axis=list(range(1, xt.ndim)))
q_t_teps = q_t_teps_mat[b, xt.long()].unsqueeze(-2)
qt0 = q_teps_0 * q_t_teps  # takes roughly 30-60 seconds

log_qt0 = torch.where(qt0 <= 0.0, -1e9, torch.log(qt0))  # takes roughly 7 minutes
start_opt = time.time()
# Add the trailing axis to log_p0t only if it does not have one yet, and do not
# overwrite log_p0t, so re-running this cell cannot stack extra axes.
log_p0t_col = log_p0t if log_p0t.dim() == log_qt0.dim() else log_p0t.unsqueeze(-1)
log_prob = torch.logsumexp(log_p0t_col + log_qt0, dim=-2)
end_opt = time.time()
print(end_opt - start_opt)


In [29]:
print(log_qt0, log_qt0.shape)
print(c, c.shape)
# log(a * b) and log(a) + log(b) agree only up to floating-point rounding,
# so compare with a tolerance rather than exact elementwise equality.
print(torch.allclose(c, log_qt0))

tensor([[[[-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708],
          [-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708],
          [-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708],
          ...,
          [-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708],
          [-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708],
          [-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708]],

         [[-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708],
          [-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708],
          [-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708],
          ...,
          [-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708],
          [-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708],
          [-11.1708, -11.1708, -11.1708,  ..., -11.1708, -11.1708, -11.1708]],

         [[-11.1708, -11.1708,

In [None]:

# Add the trailing axis only when log_p0t lacks it: an earlier cell may already
# have unsqueezed it, and a second unsqueeze breaks the broadcast below.
log_p0t_col = log_p0t if log_p0t.dim() == log_qt0.dim() else log_p0t.unsqueeze(-1)
log_prob = torch.logsumexp(log_p0t_col + log_qt0, dim=-2)
# Is `axis` not a valid parameter? There was an error here.
end_opt = time.time()
print("sampling operations time", end_opt - start_opt)

# Rebuild both transition matrices from the model, printing each stage.
q_teps_0 = model.transition(t_eps * torch.ones((B,), device=device))  # (N, S, S)
print(q_teps_0, q_teps_0.shape)
q_teps_0 = utils.expand_dims(q_teps_0, axis=list(range(1, xt.ndim)))
print(q_teps_0, q_teps_0.shape)
q_t_teps = model.transit_between(t_eps * torch.ones((B,), device=device), t * torch.ones((B,), device=device))  # (N, S, S)
print(q_t_teps, q_t_teps.shape)
# Swap the last two axes so axis 1 can later be indexed by x_t.
q_t_teps = q_t_teps.permute(0, 2, 1)
print(q_t_teps, q_t_teps.shape)

In [None]:
# Batch index broadcast against xt.
b = utils.expand_dims(torch.arange(xt.shape[0]), axis=list(range(1, xt.ndim)))
print(b, b.shape)

# This cell overwrites q_t_teps with its own indexed version. Running it twice
# would index the wrong axes without raising, so fail loudly instead.
assert q_t_teps.dim() == 3, "q_t_teps is already indexed; rebuild it before re-running this cell"

# Index with int64, matching the other cells that use xt.long().
q_t_teps = q_t_teps[b, xt.long()].unsqueeze(-2)
print(q_t_teps, q_t_teps.shape)

In [None]:
#-----------Transition matrix q_t|0: x0 -> xt ---------------------
cfg = get_config()
cfg.data.S = S
cfg.model.rate_const = rate_const
uni = UniformRate(cfg, device)
ts = torch.rand((B,))
qt0 = uni.transition(ts)
x0 = torch.randint(low=0, high=S, size=(B, D), dtype=torch.int)

# Variant 1: flat indexing. One row of qt0 per (sample, dimension), shape (B*D, S).
qt0_rows_reg = qt0[
    torch.arange(B, device=device).repeat_interleave(
        D
    ),  # repeats every element 0 to B-1 D-times
    x0.flatten().long(),  # minibatch.flatten() => (B, D) => (B*D) (1D-Tensor)
    :,
]
print(qt0_rows_reg, qt0_rows_reg.shape)

# Variant 2: the same rows via a broadcast batch index, shape (B, D, S).
b = utils.expand_dims(torch.arange(B), (tuple(range(1, x0.dim()))))
qt0_rows_reg2 = qt0[b, x0.long()]

# The mask must be taken from the tensor whose log is used. Testing `qt0`
# (B, S, S) against values of shape (B, D, S) does not broadcast.
logits = torch.where(qt0_rows_reg2 <= 0.0, -1e9, torch.log(qt0_rows_reg2))

# Draw x_t for every dimension; the rows are passed as probabilities.
x_t_cat = torch.distributions.categorical.Categorical(qt0_rows_reg)
x_t = x_t_cat.sample().view(B, D)
print(x_t, x_t.shape)

In [None]:
#-------------- Transition rate: x_t -> x_tilde ------------------
rate = uni.rate(ts)  # B, S, S

# Flat (B*D,) index vectors: owning sample and current state of each dimension.
flat_batch = torch.arange(B, device=device).repeat_interleave(D)
flat_state = x_t.long().flatten()

# Rates out of the current state of every dimension, shape (B*D, S).
rate_vals_square = rate[flat_batch, flat_state, :]
# Zero the entry for the state each dimension currently holds.
rate_vals_square[torch.arange(B * D, device=device), flat_state] = 0.0
print(rate_vals_square, rate_vals_square.shape)

rate_vals_square = rate_vals_square.view(B, D, S)
print(rate_vals_square, rate_vals_square.shape)

# Total rate per dimension, summed over the state axis.
rate_vals_square_dimsum = rate_vals_square.sum(dim=2).view(B, D)
print(rate_vals_square_dimsum, rate_vals_square_dimsum.shape)

# For every sample, draw the single dimension in which the transition happens.
square_dimcat = torch.distributions.categorical.Categorical(rate_vals_square_dimsum)
square_dims = square_dimcat.sample()
print("Where transition", square_dims, square_dims.shape)

# Rates of the chosen dimension, used as weights for the new state.
batch_range = torch.arange(B, device=device)
rate_new_val_probs = rate_vals_square[batch_range, square_dims, :]  # (B, S)
print(rate_new_val_probs, rate_new_val_probs.shape)

square_newvalcat = torch.distributions.categorical.Categorical(rate_new_val_probs)

# Shape: (B,) with values in the range [0, S)
square_newval_samples = square_newvalcat.sample()
print("Transition value", square_newval_samples, square_newval_samples.shape)

# noisy image: copy of x_t with the chosen dimension set to the sampled state
x_tilde = x_t.clone()
x_tilde[batch_range, square_dims] = square_newval_samples
print(x_t)
print(x_tilde)


In [None]:
#-----------ELBO-------------------
# Mask of shape (B, D, S): ones everywhere except at the state that x_tilde
# holds in each (sample, dimension) position, which is set to zero.
mask_reg = torch.ones((B, D, S), device=device)
mask_reg.scatter_(2, x_tilde.long().unsqueeze(-1), 0.0)
print(x_tilde)
print(mask_reg, mask_reg.shape)

In [None]:
# Flat (B*D,) index vectors shared by both indexing operations below.
flat_batch = torch.arange(B, device=device).repeat_interleave(D)
flat_tilde = x_tilde.long().flatten()

qt0_numer_reg = qt0.view(B, S, S)
print(qt0_numer_reg , qt0_numer_reg.shape)

# q_{t|0} (x|x_0): column x_tilde of qt0 per (sample, dimension), plus a small
# constant. Not used again in this cell.
qt0_denom_reg = qt0[flat_batch, :, flat_tilde].view(B, D, S) + 1e-6
#print(qt0_denom_reg, qt0_denom_reg.shape)

#print(rate, rate.shape)
# Column x_tilde of the rate matrix per (sample, dimension).
rate_vals_reg = rate[flat_batch, :, flat_tilde].view(B, D, S)
print(rate_vals_reg, rate_vals_reg.shape)

# Masked rates multiplied by the transposed transition matrix (batched matmul).
masked_rates = mask_reg * rate_vals_reg
print(masked_rates)
reg_tmp = masked_rates @ qt0_numer_reg.transpose(1, 2)
print(reg_tmp, reg_tmp.shape)

In [None]:
# Tiny problem size so that every tensor can be inspected by eye.
rate_const = 1
S, B, D = 3, 2, 4

cfg = get_config()
cfg.data.S = S
cfg.model.rate_const = rate_const
uni = UniformRate(cfg, 'cpu')

# Random times and random states, drawn in this order.
ts = torch.rand((B,))
xt = torch.randint(low=0, high=S, size=(B, D), dtype=torch.int)
print(xt)

qt0 = uni.transition(ts)

# Swap the last two axes of the transition matrix and compare with the original.
qt0_y2x = qt0.transpose(1, 2)
print(qt0, qt0.shape)
print(qt0_y2x, qt0_y2x.shape)
print(qt0 == qt0_y2x)

In [None]:
# Batch index broadcast against xt.
b = utils.expand_dims(
    torch.arange(xt.shape[0]), tuple(range(1, xt.dim()))
)
print(b, b.shape)

# Derive the rows from qt0 rather than from qt0_y2x itself, so re-running the
# cell gives the same result, and index with int64 as the other cells do.
qt0_y2x = qt0.permute(0, 2, 1)[b, xt.long()]
print(qt0_y2x, qt0_y2x.shape)

In [None]:
# The indexed transition rows are used as stand-in logits here.
logits = qt0_y2x
log_p0t = logits.log_softmax(dim=-1)
print(log_p0t, log_p0t.shape)

# Log of the transition matrix, with -1e9 wherever the probability is <= 1e-35.
log_qt0 = torch.where(qt0 <= 1e-35, -1e9, qt0.log())
print(log_qt0, log_qt0.shape)

dim_axes = list(range(1, xt.dim()))
log_qt0 = utils.expand_dims(log_qt0, axis=dim_axes)
print(log_qt0, log_qt0.shape)

# Trailing axis on log_p0t so it broadcasts against the last two axes of log_qt0.
log_p0t = log_p0t[..., None]
print(log_p0t, log_p0t.shape)

# Sum out the second-to-last axis in log space.
log_prob = (log_p0t + log_qt0).logsumexp(dim=-2)
print(log_prob, log_prob.shape)


In [None]:
# Same quantity as the log-space version, computed in probability space.
xt_onehot = F.one_hot(xt.long(), S)
qt0 = uni.transition(ts)

p0t = logits.softmax(dim=-1)
print(p0t, p0t.shape)

# Axes 1 .. xt.dim()-2 are requested; for a 2-D xt this axis list is empty.
middle_axes = list(range(1, xt.dim() - 1))
qt0 = utils.expand_dims(qt0, axis=middle_axes)
print(qt0, qt0.shape)

# Batched matrix product of the predicted distribution with the transition matrix.
prob_all = torch.matmul(p0t, qt0)
print(prob_all.shape)

# Small constant added before the log.
log_prob = (prob_all + 1e-35).log()
print(log_prob, log_prob.shape)

# Keep only the entry of the observed state x_t in every dimension.
log_xt = (log_prob * xt_onehot).sum(dim=-1)
print(log_xt, log_xt.shape)

In [None]:
qt0 = uni.transition(ts)

# ts is drawn from [0, 1), so ts - 0.1 can be negative; clamp the earlier time
# at 0. NOTE(review): assumes the process is only defined for t >= 0 — confirm.
t_eps = torch.clamp(ts - 0.1, min=0.0)
q_t_teps = uni.transit_between(t_eps, ts)
print(q_t_teps, q_t_teps.shape, qt0.shape)

# Draw x0 for the CURRENT B, D, S. The x0 left over from the earlier cell was
# sized for a different configuration and would index out of range here.
x0 = torch.randint(low=0, high=S, size=(B, D))  # int64, usable as an index
b = utils.expand_dims(torch.arange(B), (tuple(range(1, x0.dim()))))
qt0_rows_reg2 = qt0[b, x0]
print(qt0_rows_reg2, qt0_rows_reg2.shape)

# Log of the selected rows; zero probabilities map to -1e9 instead of -inf.
logits = torch.where(qt0_rows_reg2 <= 0.0, -1e9, torch.log(qt0_rows_reg2))
print(logits, logits.shape)

# The first positional argument of Categorical is `probs`; log-values must be
# passed through the `logits` keyword.
x_t_cat = torch.distributions.categorical.Categorical(logits=logits).sample()
print(x_t_cat,x_t_cat.shape)

In [None]:
# NOTE(review): assumes `logits` holds log-probabilities per state, (B, D, S) — confirm.
ll_all = logits  # B, D, S
# Log-likelihood of the observed state x_t in every dimension. Using the integer
# states themselves as log-likelihoods feeds non-negative values to log1mexp.
ll_xt = ll_all.gather(-1, xt.long().unsqueeze(-1)).squeeze(-1)  # B, D

loss = -(
    (S - 1) * ll_xt
    + torch.sum(utils.log1mexp(ll_all), dim=-1)
    - utils.log1mexp(ll_xt)
)
print(loss, loss.shape)

# Per-sample weights (all ones here), expanded to broadcast against the loss.
weight = torch.ones((B,), dtype=torch.float32)
weight = utils.expand_dims(weight, axis=list(range(1, loss.dim())))
print(weight, weight.shape)
loss = loss * weight
print(loss, loss.shape)

# Sum over all entries, divided by the batch size.
loss = torch.sum(loss) / xt.shape[0]
print(loss, loss.shape)

In [None]:
# NOTE(review): assumes `logits` holds log-probabilities per state, (B, D, S) — confirm.
ll_all = logits
# Log-likelihood of the observed state x_t (B, D), gathered from ll_all rather
# than using the integer states themselves.
ll_xt = ll_all.gather(-1, xt.long().unsqueeze(-1)).squeeze(-1)
xt_onehot = F.one_hot(xt.long(), num_classes=S)

# Batch index broadcast against xt.
b = utils.expand_dims(torch.arange(xt.shape[0]), tuple(range(1, xt.dim())))
print(b, b.shape)

qt0_x2y = uni.transition(ts)
print(qt0_x2y, qt0_x2y.shape)

# Transposed matrix, then one row per (sample, dimension). The prints show the
# tensor that was just computed (qt0_y2x), not qt0_x2y again.
qt0_y2x = qt0_x2y.permute(0, 2, 1)
print(qt0_y2x, qt0_y2x.shape)
qt0_y2x = qt0_y2x[b, xt.long()]
print(qt0_y2x, qt0_y2x.shape)

# Trailing axis so ll_xt broadcasts against the state axis of ll_all.
ll_xt = ll_xt.unsqueeze(-1)
print("ll", ll_xt, ll_xt.shape)
backwd = torch.exp(ll_all - ll_xt) * qt0_y2x
print(backwd , backwd.shape)


In [None]:
# Sum the backward term over all states except the observed one.
first_term = torch.sum(backwd * (1 - xt_onehot), dim=-1)
print(first_term , first_term.shape)

# Take the rows from a fresh transition matrix rather than from qt0_x2y itself,
# so re-running the cell gives the same result; index with int64.
qt0_x2y = uni.transition(ts)[b, xt.long()]
print(qt0_x2y, qt0_x2y.shape)

fwd = (ll_xt - ll_all) * qt0_x2y
print(fwd, fwd.shape)

# Sum the forward term over all states except the observed one.
second_term = torch.sum(fwd * (1 - xt_onehot), dim=-1)
print(second_term, second_term.shape)

loss = first_term - second_term
print(loss, loss.shape)

In [None]:
# Per-sample weights (all ones here), expanded to broadcast against the loss.
weight = torch.ones((B,), dtype=torch.float32)
trailing_axes = list(range(1, loss.dim()))
weight = utils.expand_dims(weight, axis=trailing_axes)
print(weight, weight.shape)

loss = loss * weight
print(loss, loss.shape)

# Sum over all entries, divided by the batch size.
loss = loss.sum() / xt.shape[0]
print(loss, loss.shape)

In [None]:
# Descending time grid: 1000 evenly spaced points from 1.0 down to 1e-3,
# followed by a final 0.
ts = np.append(np.linspace(1.0, 1e-3, 1000), np.array([0]))
#save_ts = ts[np.linspace(0, len(ts)-2, num_intermediates, dtype=int)]

# Visit every grid point except the trailing 0.
for idx in range(len(ts) - 1):
    t = ts[idx]
    print(t)