In [1]:
# Decoding (decompression) script: restore images from the compressed ANS message.

import os
import numpy as np
import rans
from PIL import Image
import util
from torch_vae import tvae_utils
from torch_vae.tvae_beta_binomial import BetaBinomialVAE
import torch
import time

# Parameters
# Precisions handed to the ANS coders: prior_precision / q_precision go to
# util.vae_append / util.vae_pop, obs_precision to the beta-binomial pixel coder.
prior_precision = 8
obs_precision = 14
q_precision = 14
latent_dim = 50
latent_shape = (1, latent_dim)  # presumably (batch=1, latent_dim) -- confirm against util

# Load the VAE model.
# map_location=lambda storage, location: storage keeps every tensor on the CPU
# regardless of the device it was saved from. weights_only=True restricts
# unpickling to tensors and plain containers -- all a state_dict needs -- and
# avoids torch's FutureWarning about the weights_only default.
model = BetaBinomialVAE(hidden_dim=200, latent_dim=latent_dim)
model.load_state_dict(
    torch.load('saved_params/torch_vae_beta_binomial_params',
               map_location=lambda storage, location: storage,
               weights_only=True))
model.eval()  # inference mode

# Wrap the torch encoder/decoder so util's coders can call them
# (presumably numpy in / numpy out, per the helper's name).
rec_net = tvae_utils.torch_fun_to_numpy_fun(model.encode)
gen_net = tvae_utils.torch_fun_to_numpy_fun(model.decode)

# Pixel coders. 255 is presumably the beta-binomial trial count, i.e. pixel
# values 0..255 -- confirm in tvae_utils.
obs_append = tvae_utils.beta_binomial_obs_append(255, obs_precision)
obs_pop = tvae_utils.beta_binomial_obs_pop(255, obs_precision)

# Full VAE coders. Only vae_pop is used by the decoding code below;
# vae_append is built alongside it but not called in this cell.
vae_append = util.vae_append(latent_shape, gen_net, rec_net, obs_append,
                             prior_precision, q_precision)
vae_pop = util.vae_pop(latent_shape, gen_net, rec_net, obs_pop,
                       prior_precision, q_precision)

# Load the compressed message.
# ndmin=1 keeps the result at least 1-D: for a file holding a single value,
# np.loadtxt would otherwise return a 0-d array whose .tolist() is a bare
# float, and len() on it below would raise TypeError.
compressed_lengths = np.loadtxt('compressed_image/compressed_lengths_cts',
                                ndmin=1).tolist()
compressed_message = np.fromfile('compressed_image/compressed_message.bin', dtype=np.uint32)


# Decode the images.
# Only the number of entries in compressed_lengths is used here: it is the
# number of images to pop off the ANS state.
state = rans.unflatten(compressed_message)
decode_start_time = time.time()
decoded_images = []

# Image.save does not create missing directories, so ensure the output
# directory exists before writing any PNGs.
decompressed_dir = 'decompressed_image'
os.makedirs(decompressed_dir, exist_ok=True)

for n in range(len(compressed_lengths)):
    # NOTE(review): ANS is last-in-first-out, so the first image popped is the
    # last one the encoder appended. If the encoder appended images 0..N-1 in
    # order, file index n here corresponds to original image N-1-n -- confirm
    # against the compression script before relying on the numbering.
    state, image_ = vae_pop(state)

    # Decompressed image: .numpy() implies vae_pop returns a torch tensor;
    # reshape to 28x28 for MNIST.
    image_array = image_.numpy().reshape(28, 28)
    image_save_path = os.path.join(decompressed_dir, f'decompressed_image_{n}.png')
    Image.fromarray(image_array.astype(np.uint8)).save(image_save_path)

    decoded_images.append(image_array)
    print(f'Decoded image {n}')

print('\nAll decoded in {:.2f}s'.format(time.time() - decode_start_time))

# Save the ANS state left over after popping every image, as a flat array.
recovered_bits = rans.flatten(state)
np.save('recovered_bits.npy', recovered_bits)


  torch.load('saved_params/torch_vae_beta_binomial_params',


Decoded image 0
Decoded image 1
Decoded image 2
Decoded image 3
Decoded image 4
Decoded image 5
Decoded image 6
Decoded image 7
Decoded image 8
Decoded image 9
Decoded image 10
Decoded image 11
Decoded image 12
Decoded image 13
Decoded image 14
Decoded image 15
Decoded image 16
Decoded image 17
Decoded image 18
Decoded image 19
Decoded image 20
Decoded image 21
Decoded image 22
Decoded image 23
Decoded image 24
Decoded image 25
Decoded image 26
Decoded image 27
Decoded image 28
Decoded image 29
Decoded image 30
Decoded image 31
Decoded image 32
Decoded image 33
Decoded image 34
Decoded image 35
Decoded image 36
Decoded image 37
Decoded image 38
Decoded image 39
Decoded image 40
Decoded image 41
Decoded image 42
Decoded image 43
Decoded image 44
Decoded image 45
Decoded image 46
Decoded image 47
Decoded image 48
Decoded image 49
Decoded image 50
Decoded image 51
Decoded image 52
Decoded image 53
Decoded image 54
Decoded image 55
Decoded image 56
Decoded image 57
Decoded image 58
Decoded