In [1]:
import sys
sys.path.append('../n-trees/')
sys.path.append('../utils/')
import numpy as np
import torch
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
import datasets
import loss
import models_tests
import training_loop
import utils
import fast_generator

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


# Training

In [2]:
# Create the dataset
# dataset = datasets.ForestDataset(10, 2, temp=1, maxiter=10000, size=100000)

In [3]:
# Load the pre-generated training set. These files hold whole pickled dataset objects
# (created with datasets.ForestDataset and written with torch.save(dataset, path)).
# weights_only=False is required to unpickle a full Python object: from PyTorch 2.6 on,
# torch.load defaults to weights_only=True and refuses it. Only load trusted,
# locally generated files this way.
# Other available files: ../datasets/dataset_6_1_1e5.pt, ../datasets/dataset_15_2_1e5.pt
dataset = torch.load('../datasets/dataset_10_2_1e6.pt', weights_only=False)

In [4]:
# Enable the dataset's `conditional` flag, read by datasets.ForestDataset
# (presumably makes items carry the conditioning prior — confirm in datasets.py).
setattr(dataset, 'conditional', True)

In [5]:
# Create the loss function (conditional variant, matching the conditional model below).
# Unconditional alternative: loss.EntropyLoss(P_factor=15)
loss_fn = loss.ConditionalEntropyLoss(
    P_factor=15,
)

In [6]:
# Create the model: a conditional attention network wrapped in the entropy preconditioner.
# The unconditional variant is models_tests.EntropyPrecond(n=..., d=...,
# model=models_tests.Attention(...)) with the same keyword arguments.
model = models_tests.ConditionalEntropyPrecond(
    n=dataset.n,
    d=dataset.d,
    model=models_tests.ConditionalAttention(
        dataset.n,                 # Dataset dimensions as defined by datasets.ForestDataset
        dataset.d,                 # (TODO confirm which is grid size vs. channel count).
        dropout=0.10,              # Dropout probability of intermediate activations.
        num_heads=16,              # Presumably attention heads per layer — confirm in models_tests.
        model_dimension=512,       # Presumably the transformer hidden width.
        num_encoder_layers=6,      # Presumably the number of encoder layers.
        embedding_type='fourier',  # Embedding of the noise level: 'fourier' or 'positional'.
        embedding_channels=128,    # Number of channels in that noise-level embedding.
    ),
).to(device)

In [7]:
# In-memory footprint of the model: parameter bytes plus registered-buffer bytes, in MiB.
param_size = sum(tensor.nelement() * tensor.element_size() for tensor in model.parameters())
buffer_size = sum(tensor.nelement() * tensor.element_size() for tensor in model.buffers())

size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))

model size: 73.438MB


In [8]:
# Create the optimizer: AdamW with learning rate 1e-6 and weight decay 0.005.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-6,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.005,
)

In [9]:
# Shape of the first element of the first dataset item.
dataset[0][0].size()

torch.Size([1, 10, 10])

In [10]:
# Smoke test: push 4 copies of one sample (with a random prior and p = 0.5) through the
# model and inspect the output shape.
# NOTE(review): the prior shape (4, 10, 10, 10) is hard-coded for the 10x10 dataset.
with torch.no_grad():  # shape probe only; no autograd graph needed
    probe_out = model(
        dataset[0][0].repeat(4, 1, 1, 1).to(device),
        torch.randn(4, 10, 10, 10).to(device),
        torch.tensor([0.5, 0.5, 0.5, 0.5]).to(device),
    )
probe_out.shape

torch.Size([4, 1, 10, 10])

In [11]:
# Train. Durations are measured in kimg (thousands of training samples):
# total_kimg = 2**21 = 2,097,152 kimg; kimg_per_tick = 2**11 = 2,048 kimg.
training_loop.simple_training_loop(
    run_dir='../experiments/test16_attention_conditional',  # Output directory.
    dataset=dataset,          # Training dataset object.
    network=model,            # Preconditioned model to train.
    loss=loss_fn,             # Loss function.
    optimizer=optimizer,      # Optimizer over model.parameters().
    seed=0,                   # Global random seed.
    batch_size=1024,          # Total batch size for one training iteration.
    num_workers=16,           # Number of data loading workers.
    total_kimg=2**21,         # Training duration in thousands of training samples.
    device=device,
    kimg_per_tick=2**11,      # Tick interval; per the original note, how often training state is saved.
)

Loading dataset...




Constructing network...
Setting up optimizer...
Setting up logs...
Training for 2097152 kimg...




# Inference

In [None]:
# Load a trained network for inference.
# Earlier checkpoints: ../experiments/test01/training-state-100000.pt,
#                      ../experiments/test05/training-state-100000.pt
# NOTE(review): the training cell writes to '../experiments/test16_attention_conditional',
# but this loads 'test12_attention'. Later cells call model(x, prior, p), so confirm this
# checkpoint holds a conditional model.
# weights_only=False: 'net' is a full pickled module (it is moved with .to() and put in
# .eval() mode). From PyTorch 2.6 on, torch.load refuses such objects by default.
# Only load trusted checkpoints this way.
model = torch.load(
    '../experiments/test12_attention/training-state-680000.pt', weights_only=False
)['net'].to(device).eval()

In [None]:
# Generate a small fresh ForestDataset for evaluation. Arguments are positional; they
# appear to mirror ForestDataset(10, 2, temp=1, maxiter=10000, size=...) used for the
# training data (TODO confirm the positional order in datasets.py).
test_dataset = datasets.ForestDataset(
    10, 2,
    1, 10000, 100,
)
# 6x6 variant: datasets.ForestDataset(6, 1, 1, 10000, 100)

In [None]:
# NOTE(review): if run, this replaces the freshly generated test_dataset from the previous
# cell with the same file loaded as the training set, so the evaluation below would use
# training data. `conditional = True` was only set on the in-memory training copy; set it
# here too if the saved object needs it.
# weights_only=False is needed to unpickle a full dataset object on PyTorch >= 2.6.
test_dataset = torch.load('../datasets/dataset_10_2_1e6.pt', weights_only=False)

In [None]:
# One-step denoising demo: corrupt a clean sample with random bit flips at a randomly
# drawn flip probability p, then ask the model to recover it.
data, prior = test_dataset[0]
# Add a leading batch axis. For a 2-D (H, W) sample this equals the previous
# reshape(1, H, W); unlike that reshape it also works for samples that already carry a
# leading axis, e.g. (1, H, W).
data = data.to(device).to(torch.float32).unsqueeze(0)
prior = prior.to(device).to(torch.float32).unsqueeze(0)

# Log-normal draw of the flip probability, clamped to [0, 0.5].
rnd_normal = torch.randn([data.shape[0], 1, 1], device=data.device)
p = rnd_normal.exp()
p = (p / 20).clamp(0, 1) / 2

# Alternative: p = torch.distributions.beta.Beta(1.3, 4).sample([data.shape[0], 1, 1]).to(data.device) / 2

y = data
# n is +1 with probability 1 - p and -1 with probability p; y * n flips those bits.
n = torch.bernoulli(torch.ones_like(y) * (1 - p)) * 2 - 1
n = n.to(torch.float32)
with torch.no_grad():  # inference only; no autograd graph needed
    D_yn = model(y * n, prior, p)

print(p)

In [None]:
# Visual check of the one-step denoising demo, first sample of the batch.
# .squeeze() drops singleton axes so imshow always receives a 2-D array.
fig, ax = plt.subplots(1, 6, figsize=(25, 5))

ax[0].imshow(data[0].squeeze().cpu().numpy())
ax[0].set_title('clean data')
ax[1].imshow(n[0].squeeze().cpu().numpy())
ax[1].set_title('noise n (-1 = flipped bit)')
ax[2].imshow((y * n)[0].squeeze().cpu().numpy())
ax[2].set_title('noisy input y * n')
ax[3].imshow(D_yn[0].detach().squeeze().cpu().numpy())
ax[3].set_title('denoised D(y * n)')
ax[4].imshow((D_yn[0].detach() - data[0]).squeeze().cpu().numpy(), vmin=-0.5, vmax=0.5)
ax[4].set_title('D(y * n) - data')
# argmax over the prior's leading axis; argmax indices are already non-negative.
ax[5].imshow(torch.argmax(prior[0], dim=0).cpu().numpy())
ax[5].set_title('prior (argmax over axis 0)')
plt.show()

In [None]:
def entropic_sampler(net, latents, priors=None, num_steps=100, churn = 1, p_min = 0.001, rho=4):
    """Sample +/-1 fields by iterative denoising under bit-flip noise.

    The noise level t is a per-element flip probability: t = 0.5 is pure noise and
    t = 0 is clean data. The schedule goes from 0.5 down to ``p_min`` over
    ``num_steps`` steps, warped by ``rho``, using the same functional form as the
    EDM (Karras et al.) schedule. A final t = 0 is appended.

    Args:
        net: denoiser, called as ``net(x, priors, t)`` when ``priors`` is given and
            as ``net(x, t)`` otherwise. It should return a clean-data estimate with
            the same shape as ``x``. The output is presumably in [-1, 1]; it is
            clamped to that range when turned into probabilities.
        latents: starting +/-1 field, e.g. independent fair coin flips.
        priors: optional conditioning tensor forwarded to ``net``. It is cast to
            float64 and moved to ``latents.device``.
        num_steps: number of denoising steps, i.e. the number of yielded states.
        churn: relative amount of extra noise re-injected before each step.
            While p_min <= t_cur <= 0.5, t_hat = t_cur * (1 + churn / num_steps).
            Set to 0 to disable churn.
        p_min: smallest non-zero noise level in the schedule.
        rho: schedule warping exponent.

    Yields:
        ``utils.EasyDict`` with ``x`` (the float64 state after the step) and
        ``denoised`` (the float64 network output at that step).
    """
    # Time step discretization; max(..., 1) avoids 0/0 when num_steps == 1.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_max_root = 0.5 ** (1 / rho)
    t_min_root = p_min ** (1 / rho)
    t_steps = (t_max_root + step_indices / max(num_steps - 1, 1) * (t_min_root - t_max_root)) ** rho
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    if priors is not None:
        priors = priors.to(torch.float64).to(latents.device)

    x_next = latents.to(torch.float64)
    for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
        x_cur = x_next

        if churn > 0:
            # Raise the noise level from t_cur to t_hat by flipping each spin
            # independently with probability (t_hat - t_cur). Treating the levels as
            # adding is a first-order approximation: composing flip probabilities t
            # and q exactly gives t + q - 2*t*q.
            gamma = churn / num_steps if p_min <= t_cur <= 0.5 else 0
            t_hat = torch.as_tensor(t_cur + gamma * t_cur)
            flip_prob = (t_hat - t_cur).clamp(0, 1)
            flips = torch.bernoulli(torch.ones_like(x_cur) * flip_prob)
            # Map {0, 1} -> {+1, -1} so that only the drawn spins change sign.
            x_hat = x_cur * (1 - 2 * flips)
        else:
            t_hat = t_cur
            x_hat = x_cur

        # Network evaluation; no gradients are needed while sampling.
        with torch.no_grad():
            if priors is not None:
                denoised = net(x_hat, priors, t_hat).to(torch.float64)
            else:
                denoised = net(x_hat, t_hat).to(torch.float64)

        # Draw x0 from the denoiser's estimate, with P[x0 = s] = (1 + denoised * s) / 2,
        # then re-noise it to level t_next. Marginally this gives
        #   P[x_next == x_hat] = (1 + denoised * x_hat * (1 - 2 * t_next)) / 2,
        # which follows the denoiser at t_next = 0 and is a fair coin at t_next = 0.5.
        keep_prob = (1 + denoised.clamp(-1, 1) * x_hat * (1 - 2 * t_next)) / 2
        keep = torch.bernoulli(keep_prob)
        x_next = x_hat * (2 * keep - 1)

        yield utils.EasyDict(x=x_next, denoised=denoised)

In [None]:
# Grid side length of the generated samples (use d = 6 for the 6x6 dataset).
d = 10
# Starting state for sampling: independent fair +/-1 spins.
latents = 2 * torch.bernoulli(torch.full((1, d, d), 0.5, device=device)) - 1


In [None]:
# Conditioning prior of test item 15, with a leading batch axis added.
prior = test_dataset[15][1][None]

In [None]:
# Build the step-by-step sampling generator (50 steps, light churn) for the chosen prior.
sampler = entropic_sampler(
    model,
    latents,
    priors=prior,
    num_steps=50,
    churn=0.01,
    p_min=1e-4,
    rho=4,
)

In [None]:
# Denoised estimate after the first sampling step (consumes one step of `sampler`).
plt.imshow(next(sampler).denoised[0].detach().cpu().numpy())

In [None]:
# Run the remaining sampler steps and plot the final denoised forest next to its prior.
# Iterating the generator, instead of calling next() a fixed 50 times, avoids
# StopIteration when an earlier cell already consumed a step. It also builds the plot
# input once rather than on every step.
final_step = None
for final_step in sampler:
    pass
assert final_step is not None, 'sampler is exhausted; re-create it before running this cell'

denoised_map = final_step.denoised[0].detach().cpu().numpy()
# (x + 1.3) / 2 truncated to int: for denoised values in [-1, 1] this gives 1 where the
# estimate is at or above about 0.7, and 0 elsewhere.
occupancy = ((denoised_map + 1.3) / 2).astype(np.int64)
prior_map = torch.argmax(prior[0], dim=0).cpu().numpy()
forest = np.stack((occupancy, prior_map))
fast_generator.plot_forest(forest)

In [None]:
# p sample weighting, log-normal: exp(1.2 * N(0, 1)) / 20, clipped to [0, 1].
rnd_normal = 1.2 * torch.randn([100000])
p = torch.clamp(torch.exp(rnd_normal) / 20, 0, 1)
plt.hist(p, bins=500)
plt.show()

In [None]:
# p sample weighting, log-normal: exp(1.2 * N(0, 1)) / 15 clipped to [0, 1], then halved
# so flip probabilities lie in [0, 0.5].
rnd_normal = 1.2 * torch.randn([100000])
p = torch.clamp(torch.exp(rnd_normal) / 15, 0, 1) / 2
plt.hist(p, bins=500)
plt.show()

In [None]:
# Beta(1.5, 10) candidate distribution for p.
beta = torch.distributions.Beta(1.5, 10)
p = beta.sample((100000,))
plt.hist(p, bins=500)
plt.show()

In [None]:
# Log-normal with a much smaller scale: exp(1.2 * N(0, 1)) / 200, clipped to [0, 1];
# the cell's value is the sample mean.
rnd_normal = 1.2 * torch.randn([100000])
p = torch.clamp(torch.exp(rnd_normal) / 200, 0, 1)
plt.hist(p, bins=500)
plt.show()
p.mean()

In [None]:
# Beta(10, 2) candidate distribution for p (mass concentrated near 1).
beta = torch.distributions.Beta(10, 2)
p = beta.sample((100000,))
plt.hist(p, bins=500)
plt.show()

In [None]:
# Uniform p on [0, 0.5).
p = 0.5 * torch.rand(100000)
plt.hist(p, bins=500)
plt.show()