In [1]:
import numpy as np
import util
import rans

In [156]:
import os
import sys
sys.path.append(os.path.abspath('../pytorch-wavenet'))
import warnings
warnings.filterwarnings("ignore", category=UserWarning) 
warnings.filterwarnings("ignore", category=FutureWarning) 

import time
import argparse
import json
import math

from wavenet_vocoder import WaveNet
from audio_data import AudioDataset

import torch
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# from util import parameter_count

from apex import amp
import mpld3
import matplotlib.pyplot as plt

In [3]:
prior_precision = 8
obs_precision = 14
q_precision = 14

rng = np.random.RandomState(0)
np.seterr(over='raise');

# Encoding and decoding with uniform distribution

In [4]:
range_exp = 8

other_bits = rng.randint(1 << range_exp, size=2, dtype=np.uint32)
print("other_bits", other_bits)

state = rans.x_init
print("state 1", state)
state = util.uniforms_append(range_exp)(state, other_bits)
print("state 2", state)

state, recovered_bits = util.uniforms_pop(range_exp, other_bits.shape[0])(state)
print("state", state)
print("recovered_bits", recovered_bits)

other_bits [172  47]
state 1 (2147483648, ())
state 2 (140737488367532, ())
state (2147483648, ())
recovered_bits [172  47]


# Encoding and decoding with categorical distribution

```python
def categoricals_append(probs, precision):
    """Assume that the last dim of probs contains the probability vectors,
    i.e. np.sum(probs, axis=-1) == ones"""
    # Flatten all but last dim of probs
    probs = np.reshape(probs, (-1, np.shape(probs)[-1]))
    cdfs = [categorical_cdf(p, precision) for p in probs]
    def append(state, data):
        data = np.ravel(data)
        return non_uniforms_append(precision, cdfs)(state, data)
    return append

def categoricals_pop(probs, precision):
    """Assume that the last dim of probs contains the probability vectors,
    i.e. np.sum(probs, axis=-1) == ones"""
    # Flatten all but last dim of probs
    data_shape = np.shape(probs)[:-1]
    probs = np.reshape(probs, (-1, np.shape(probs)[-1]))
    cdfs = [categorical_cdf(p, precision) for p in probs]
    ppfs = [categorical_ppf(p, precision) for p in probs]

    def pop(state):
        state, symbols = non_uniforms_pop(precision, ppfs, cdfs)(state)
        return state, np.reshape(symbols, data_shape)
    return pop
```

In [215]:
# range_exp = 3
precision = 14
# probs = np.array([[1/32, 31/32], [1/4, 3/4]])
probs = F.softmax(torch.rand(3, 8), dim=-1).numpy()

data = np.array([1, 2, 3])# rng.randint(1 << range_exp, size=2, dtype=np.uint32)
print("data", data)

data [1 2 3]


In [254]:
state = rans.x_init
print("state 1", state)
state = util.categoricals_append(probs, precision)(state, data)
print("state 2", state)

state 1 (2147483648, ())
state 2 (857489215995, ())


In [257]:
state, recovered_data = util.categoricals_pop(probs[2:3, :], precision)(state)
state, recovered_data

((2147483648, ()), array([3]))

In [243]:
util.categoricals_pop(probs[-3:-2, :], precision)(state)

((7007908502, ()), array([2]))

# Encoding with WaveNet

In [5]:
# torch.cuda.set_device(1)
prfx = '../pytorch-wavenet'
mpld3.disable_notebook()

In [18]:
snapshot_name = "48"
with open(f"{prfx}/configs/{snapshot_name}.json") as f:
    data = f.read()
config = json.loads(data)
wavenet_args = config["wavenet_args"]
train_args = config["train_args"]
batch_size = train_args["batch_size"]
epochs = train_args["epochs"]
weight_decay  = train_args["weight_decay"]
continue_training_at_step = train_args["continue_training_at_step"]
snapshot_path = f"snapshots/{snapshot_name}"
snapshot_interval = train_args["snapshot_interval"]
lr = train_args["lr"]
device_name = config["device"]
dataset_path = config["dataset_path"]
load_path = config["load_path"]
type = config.get("type", "wavenet")
assert type == "wavenet"
class WaveNetWrapper(WaveNet):
    def __init__(self, *args, **kwargs):
        print("kwargs", kwargs)
        super().__init__(*args, **kwargs)

    def forward(self, input):
        return super().forward(self.one_hot(input))
    
    def one_hot(self, input):
        one_hot_input = torch.zeros(input.size(0), self.out_channels, input.size(1), device=torch.device('cpu'))
        one_hot_input.scatter_(1, input.unsqueeze(1), 1.)
        return one_hot_input
model = WaveNetWrapper(**wavenet_args)

kwargs {'out_channels': 256, 'layers': 20, 'stacks': 2, 'residual_channels': 512, 'gate_channels': 512, 'skip_out_channels': 512, 'kernel_size': 3, 'dropout': 0.05, 'cin_channels': -1, 'gin_channels': -1, 'weight_normalization': True, 'scalar_input': False, 'legacy': False}


In [7]:
dataset = AudioDataset(f'{prfx}/{dataset_path}', model.receptive_field*8)        

print('the dataset has ' + str(len(dataset)) + ' items')
print(f'each item has length {dataset.len_sample}')

the dataset has 75535 items
each item has length 32744


In [19]:
checkpoint_dict = torch.load(f"{prfx}/snapshots/{snapshot_name}/{snapshot_name}_34000", map_location='cpu')
model = nn.DataParallel(model)
model.load_state_dict(checkpoint_dict['model'])
model = model.module
# model = model.to('cuda')

In [20]:
torch.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [427]:
x = dataset[1].long().unsqueeze(0)

In [428]:
y = x[:,  1:]
torch.manual_seed(0)
y_hat = model(x)[:, :, :-1]
loss = F.cross_entropy(y_hat, y, reduction='sum') / x.size(0)

In [434]:
8 / ((loss/math.log(2)) / x.size(-1))

tensor(2.1238, grad_fn=<MulBackward0>)

In [None]:
2.1229

In [24]:
all_x_input = torch.cat([torch.zeros_like(x), x], dim=1)

In [25]:
torch.manual_seed(0)
np.random.seed(0)
all_x_output = model(all_x_input).detach()

In [26]:
model.eval()

WaveNetWrapper(
  (first_conv): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
  (conv_layers): ModuleList(
    (0): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(2,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
    )
    (1): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(2,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
    )
    (2): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(4,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
    )
    (3): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(16,), dilation=

In [86]:
out = model.incremental_forward(softmax=False,quantize=False)

patching worked!


In [87]:
out[0,:,1]

tensor([-6.2772e+00, -7.2704e+00, -7.2539e+00, -7.2024e+00, -7.2312e+00,
        -6.8722e+00, -6.2317e+00, -6.0768e+00, -5.8338e+00, -5.2100e+00,
        -5.0316e+00, -4.8122e+00, -4.8287e+00, -4.8191e+00, -4.7682e+00,
        -4.8402e+00, -5.0198e+00, -4.8576e+00, -4.6458e+00, -4.6525e+00,
        -4.5661e+00, -4.4765e+00, -4.4847e+00, -4.2631e+00, -4.5779e+00,
        -4.7897e+00, -4.6694e+00, -5.1281e+00, -4.8903e+00, -5.2669e+00,
        -5.0791e+00, -4.8225e+00, -5.4319e+00, -5.7692e+00, -5.0441e+00,
        -5.2016e+00, -4.9732e+00, -4.8302e+00, -5.1365e+00, -5.2112e+00,
        -5.3023e+00, -5.3787e+00, -5.5397e+00, -5.5151e+00, -5.9245e+00,
        -5.8163e+00, -5.9782e+00, -5.2996e+00, -5.4737e+00, -5.5200e+00,
        -5.5589e+00, -5.4932e+00, -5.2754e+00, -5.1180e+00, -5.1033e+00,
        -4.5851e+00, -5.0620e+00, -4.9830e+00, -4.8531e+00, -4.7613e+00,
        -4.5676e+00, -4.4176e+00, -4.2322e+00, -3.9553e+00, -3.8968e+00,
        -3.3153e+00, -3.3538e+00, -3.3101e+00, -3.0

In [74]:
out[0,:,1]

tensor([-6.2772e+00, -7.2704e+00, -7.2539e+00, -7.2024e+00, -7.2312e+00,
        -6.8722e+00, -6.2317e+00, -6.0768e+00, -5.8338e+00, -5.2100e+00,
        -5.0316e+00, -4.8122e+00, -4.8287e+00, -4.8191e+00, -4.7682e+00,
        -4.8402e+00, -5.0198e+00, -4.8576e+00, -4.6458e+00, -4.6525e+00,
        -4.5661e+00, -4.4765e+00, -4.4847e+00, -4.2631e+00, -4.5779e+00,
        -4.7897e+00, -4.6694e+00, -5.1281e+00, -4.8903e+00, -5.2669e+00,
        -5.0791e+00, -4.8225e+00, -5.4319e+00, -5.7692e+00, -5.0441e+00,
        -5.2016e+00, -4.9732e+00, -4.8302e+00, -5.1365e+00, -5.2112e+00,
        -5.3023e+00, -5.3787e+00, -5.5397e+00, -5.5151e+00, -5.9245e+00,
        -5.8163e+00, -5.9782e+00, -5.2996e+00, -5.4737e+00, -5.5200e+00,
        -5.5589e+00, -5.4932e+00, -5.2754e+00, -5.1180e+00, -5.1033e+00,
        -4.5851e+00, -5.0620e+00, -4.9830e+00, -4.8531e+00, -4.7613e+00,
        -4.5676e+00, -4.4176e+00, -4.2322e+00, -3.9553e+00, -3.8968e+00,
        -3.3153e+00, -3.3538e+00, -3.3101e+00, -3.0

In [49]:
first_x_input = torch.cat([torch.zeros_like(x), torch.zeros_like(x)], dim=1)
first_x_input[0, -1] = x[0, 0]
first_x_input

tensor([[  0,   0,   0,  ...,   0,   0, 189]])

In [61]:
torch.manual_seed(0)
np.random.seed(0)
first_x_output = model(first_x_input).detach()

In [62]:
all_x_output[0, :, x.size(-1)]

tensor([-11.9625, -13.2330, -12.2922, -12.2994, -12.1627, -12.3003, -12.2914,
        -11.8946, -11.8871, -11.0082, -10.7952, -10.3780, -10.3103,  -9.7176,
         -9.7637,  -9.4829,  -9.4935,  -9.1279,  -9.1220,  -9.1577,  -9.3137,
         -9.6153,  -9.3507,  -9.0815,  -9.1525,  -8.8211,  -8.3858,  -7.9498,
         -7.2788,  -6.7857,  -6.6121,  -6.2381,  -6.1908,  -6.1728,  -6.0492,
         -5.9306,  -5.8768,  -5.6382,  -5.6412,  -5.2648,  -4.9481,  -4.7011,
         -4.7615,  -4.5728,  -4.5605,  -4.1412,  -4.1418,  -4.1206,  -3.9943,
         -3.9615,  -3.7657,  -3.7484,  -3.3526,  -3.2441,  -2.9082,  -2.9404,
         -2.9339,  -2.6680,  -2.7658,  -2.6940,  -2.7051,  -2.4763,  -2.5125,
         -2.4400,  -2.3020,  -2.2643,  -2.3682,  -2.2020,  -2.2429,  -2.2758,
         -2.3785,  -2.2099,  -2.2819,  -2.4777,  -2.2770,  -2.4424,  -2.5678,
         -2.5204,  -2.5705,  -2.6369,  -2.5514,  -2.6978,  -2.7095,  -2.7936,
         -2.7543,  -2.7318,  -2.7664,  -2.6829,  -3.1051,  -3.13

In [65]:
first_x_output[0, :, -1]

tensor([-11.5277, -12.6183, -11.7072, -11.7662, -11.5636, -11.6795, -11.6077,
        -11.3051, -11.2890, -10.4578, -10.2337,  -9.8175,  -9.7567,  -9.2041,
         -9.3712,  -9.0331,  -9.0302,  -8.7036,  -8.7306,  -8.7778,  -8.8900,
         -9.0959,  -8.8060,  -8.5355,  -8.6070,  -8.3687,  -8.0218,  -7.6326,
         -7.0911,  -6.5718,  -6.3698,  -5.9204,  -5.6828,  -5.6120,  -5.5672,
         -5.3691,  -5.3630,  -5.0619,  -5.0921,  -4.7758,  -4.4459,  -4.1894,
         -4.2153,  -4.0333,  -4.0095,  -3.6975,  -3.5521,  -3.3636,  -3.2109,
         -3.1330,  -3.0753,  -3.1495,  -2.7864,  -2.6340,  -2.3848,  -2.3327,
         -2.1958,  -2.1415,  -2.2187,  -2.2180,  -2.1524,  -1.9440,  -2.0950,
         -2.0151,  -1.8984,  -1.8868,  -2.0494,  -2.0144,  -1.9881,  -2.0782,
         -2.1385,  -2.0467,  -2.0528,  -2.3017,  -2.1262,  -2.1894,  -2.1996,
         -2.1059,  -2.1616,  -2.2772,  -2.2939,  -2.5053,  -2.4807,  -2.3863,
         -2.5066,  -2.5174,  -2.5686,  -2.4618,  -2.8314,  -2.85

## incremental_forward

In [52]:
def incremental_forward(self, all_x, initial_input=None, c=None, g=None,
                        T=100, test_inputs=None,
                        tqdm=lambda x: x, softmax=False, quantize=False,
                        log_scale_min=-7.0):
    print("patching worked!")
    """Incremental forward step

    Due to linearized convolutions, inputs of shape (B x C x T) are reshaped
    to (B x T x C) internally and fed to the network for each time step.
    Input of each time step will be of shape (B x 1 x C).

    Args:
        initial_input (Tensor): Initial decoder input, (B x C x 1)
        c (Tensor): Local conditioning features, shape (B x C' x T)
        g (Tensor): Global conditioning features, shape (B x C'' or B x C''x 1)
        T (int): Number of time steps to generate.
        test_inputs (Tensor): Teacher forcing inputs (for debugging)
        tqdm (lamda) : tqdm
        softmax (bool) : Whether applies softmax or not
        quantize (bool): Whether quantize softmax output before feeding the
            network output to input for the next time step. TODO: rename
        log_scale_min (float):  Log scale minimum value.

    Returns:
        Tensor: Generated one-hot encoded samples. B x C x T　
            or scaler vector B x 1 x T
    """
    self.clear_buffer()
    B = 1

    # Note: shape should be **(B x T x C)**, not (B x C x T) opposed to
    # batch forward due to linealized convolution
    if test_inputs is not None:
        if self.scalar_input:
            if test_inputs.size(1) == 1:
                test_inputs = test_inputs.transpose(1, 2).contiguous()
        else:
            if test_inputs.size(1) == self.out_channels:
                test_inputs = test_inputs.transpose(1, 2).contiguous()

        B = test_inputs.size(0)
        if T is None:
            T = test_inputs.size(1)
        else:
            T = max(T, test_inputs.size(1))
    # cast to int in case of numpy.int64...
    T = int(T)

    # Global conditioning
    if g is not None:
        if self.embed_speakers is not None:
            g = self.embed_speakers(g.view(B, -1))
            # (B x gin_channels, 1)
            g = g.transpose(1, 2)
            assert g.dim() == 3
#     g_btc = _expand_global_features(B, T, g, bct=False)

    # Local conditioning
    if c is not None and self.upsample_conv is not None:
        # B x 1 x C x T
        c = c.unsqueeze(1)
        for f in self.upsample_conv:
            c = f(c)
        # B x C x T
        c = c.squeeze(1)
        assert c.size(-1) == T
    if c is not None and c.size(-1) == T:
        c = c.transpose(1, 2).contiguous()

    outputs = []
    if initial_input is None:
        if self.scalar_input:
            initial_input = torch.zeros(B, 1, 1)
        else:
            initial_input = torch.zeros(B, 1, self.out_channels)
            initial_input[:, :, 127] = 1  # TODO: is this ok?
        # https://github.com/pytorch/pytorch/issues/584#issuecomment-275169567
        if next(self.parameters()).is_cuda:
            initial_input = initial_input.cuda()
    else:
        if initial_input.size(1) == self.out_channels:
            initial_input = initial_input.transpose(1, 2).contiguous()

    current_input = initial_input

    for t in tqdm(range(all_x.size(-1))):
        if t % 1000 == 0:
            print("iteration", t)
        if test_inputs is not None and t < test_inputs.size(1):
            current_input = test_inputs[:, t, :].unsqueeze(1)
        else:
            if t > 0:
#                 print("x.size()", all_x.size())
                current_input = all_x[:, :, t-1]

        # Conditioning features for single time step
        ct = None if c is None else c[:, t, :].unsqueeze(1)
        gt = None if g is None else g_btc[:, t, :].unsqueeze(1)

        x = current_input
        x = self.first_conv.incremental_forward(x)
        skips = None
        for f in self.conv_layers:
            x, h = f.incremental_forward(x, ct, gt)
            if self.legacy:
                skips = h if skips is None else (skips + h) * math.sqrt(0.5)
            else:
                skips = h if skips is None else (skips + h)
        x = skips
        for f in self.last_conv_layers:
            try:
                x = f.incremental_forward(x)
            except AttributeError:
                x = f(x)

        # Generate next input by sampling
        if self.scalar_input:
            x = sample_from_discretized_mix_logistic(
                x.view(B, -1, 1), log_scale_min=log_scale_min)
        else:
            x = F.softmax(x.view(B, -1), dim=1) if softmax else x.view(B, -1)
            if quantize:
                sample = np.random.choice(
                    np.arange(self.out_channels), p=x.view(-1).data.cpu().numpy())
                x.zero_()
                x[:, sample] = 1.0
        outputs += [x.data]
    # T x B x C
    outputs = torch.stack(outputs)
    # B x C x T
    outputs = outputs.transpose(0, 1).transpose(1, 2).contiguous()

    self.clear_buffer()
    return outputs

import types
model.incremental_forward = types.MethodType(incremental_forward, model)

In [331]:
x_uncompressed = x # x[:, :8000]

In [332]:
logits = model.incremental_forward(model.one_hot(x_uncompressed))

patching worked!
iteration 0
iteration 1000
iteration 2000
iteration 3000
iteration 4000
iteration 5000
iteration 6000
iteration 7000
iteration 8000
iteration 9000
iteration 10000
iteration 11000
iteration 12000
iteration 13000
iteration 14000
iteration 15000
iteration 16000
iteration 17000
iteration 18000
iteration 19000
iteration 20000
iteration 21000
iteration 22000
iteration 23000
iteration 24000
iteration 25000
iteration 26000
iteration 27000
iteration 28000
iteration 29000
iteration 30000
iteration 31000
iteration 32000


```python
def categoricals_append(probs, precision):
    """Assume that the last dim of probs contains the probability vectors,
    i.e. np.sum(probs, axis=-1) == ones"""
    # Flatten all but last dim of probs
    probs = np.reshape(probs, (-1, np.shape(probs)[-1]))
    cdfs = [categorical_cdf(p, precision) for p in probs]
    def append(state, data):
        data = np.ravel(data)
        return non_uniforms_append(precision, cdfs)(state, data)
    return append

def categoricals_pop(probs, precision):
    """Assume that the last dim of probs contains the probability vectors,
    i.e. np.sum(probs, axis=-1) == ones"""
    # Flatten all but last dim of probs
    data_shape = np.shape(probs)[:-1]
    probs = np.reshape(probs, (-1, np.shape(probs)[-1]))
    cdfs = [categorical_cdf(p, precision) for p in probs]
    ppfs = [categorical_ppf(p, precision) for p in probs]

    def pop(state):
        state, symbols = non_uniforms_pop(precision, ppfs, cdfs)(state)
        return state, np.reshape(symbols, data_shape)
    return pop
```

In [420]:
precision = 21

In [421]:
probs.shape

(32744, 256)

In [422]:
# probs = np.array([[1/32, 31/32], [1/4, 3/4]])
probs = F.softmax(logits, dim=1).squeeze().numpy().transpose()

In [423]:
x_uncompressed[0, -1]

tensor(208)

In [424]:
data = x_uncompressed.squeeze().numpy()
# print("data", data)
state = rans.x_init
# print("state 1", state)
state = util.categoricals_append(probs, precision)(state, data)
# print("state 2", state)

In [425]:
flat_state = rans.flatten(state)
flat_state.dtype, flat_state.shape

(dtype('uint32'), (3856,))

In [426]:
32744 / (3856*4)

2.1229253112033195

In [182]:
data.shape

(8000,)

In [328]:
def incremental_forward_recover(self, state, length, initial_input=None, c=None, g=None,
                        T=100, test_inputs=None,
                        tqdm=lambda x: x, softmax=False, quantize=False,
                        log_scale_min=-7.0):
    print("patching worked!")
    """Incremental forward step

    Due to linearized convolutions, inputs of shape (B x C x T) are reshaped
    to (B x T x C) internally and fed to the network for each time step.
    Input of each time step will be of shape (B x 1 x C).

    Args:
        initial_input (Tensor): Initial decoder input, (B x C x 1)
        c (Tensor): Local conditioning features, shape (B x C' x T)
        g (Tensor): Global conditioning features, shape (B x C'' or B x C''x 1)
        T (int): Number of time steps to generate.
        test_inputs (Tensor): Teacher forcing inputs (for debugging)
        tqdm (lamda) : tqdm
        softmax (bool) : Whether applies softmax or not
        quantize (bool): Whether quantize softmax output before feeding the
            network output to input for the next time step. TODO: rename
        log_scale_min (float):  Log scale minimum value.

    Returns:
        Tensor: Generated one-hot encoded samples. B x C x T　
            or scaler vector B x 1 x T
    """
    self.clear_buffer()
    B = 1

    # Note: shape should be **(B x T x C)**, not (B x C x T) opposed to
    # batch forward due to linealized convolution
    if test_inputs is not None:
        if self.scalar_input:
            if test_inputs.size(1) == 1:
                test_inputs = test_inputs.transpose(1, 2).contiguous()
        else:
            if test_inputs.size(1) == self.out_channels:
                test_inputs = test_inputs.transpose(1, 2).contiguous()

        B = test_inputs.size(0)
        if T is None:
            T = test_inputs.size(1)
        else:
            T = max(T, test_inputs.size(1))
    # cast to int in case of numpy.int64...
    T = int(T)

    # Global conditioning
    if g is not None:
        if self.embed_speakers is not None:
            g = self.embed_speakers(g.view(B, -1))
            # (B x gin_channels, 1)
            g = g.transpose(1, 2)
            assert g.dim() == 3
#     g_btc = _expand_global_features(B, T, g, bct=False)

    # Local conditioning
    if c is not None and self.upsample_conv is not None:
        # B x 1 x C x T
        c = c.unsqueeze(1)
        for f in self.upsample_conv:
            c = f(c)
        # B x C x T
        c = c.squeeze(1)
        assert c.size(-1) == T
    if c is not None and c.size(-1) == T:
        c = c.transpose(1, 2).contiguous()

    outputs = []
    if initial_input is None:
        if self.scalar_input:
            initial_input = torch.zeros(B, 1, 1)
        else:
            initial_input = torch.zeros(B, 1, self.out_channels)
            initial_input[:, :, 127] = 1  # TODO: is this ok?
        # https://github.com/pytorch/pytorch/issues/584#issuecomment-275169567
        if next(self.parameters()).is_cuda:
            initial_input = initial_input.cuda()
    else:
        if initial_input.size(1) == self.out_channels:
            initial_input = initial_input.transpose(1, 2).contiguous()

    current_input = initial_input
    recovered_data = np.zeros(length, dtype=np.int32)
    for t in tqdm(range(length)):
        if t % 1000 == 0:
            print("iteration", t)
        if test_inputs is not None and t < test_inputs.size(1):
            current_input = test_inputs[:, t, :].unsqueeze(1)
        else:
            if t > 0:
                current_input_scalar = recovered_data[t-1]
#                 print("current_input_scalar", current_input_scalar)
                current_input = torch.zeros(1, 256)
                current_input[0, current_input_scalar] = 1.
                # all_x[:, :, t-1]
                

        # Conditioning features for single time step
        ct = None if c is None else c[:, t, :].unsqueeze(1)
        gt = None if g is None else g_btc[:, t, :].unsqueeze(1)

        x = current_input
        x = self.first_conv.incremental_forward(x)
        skips = None
        for f in self.conv_layers:
            x, h = f.incremental_forward(x, ct, gt)
            if self.legacy:
                skips = h if skips is None else (skips + h) * math.sqrt(0.5)
            else:
                skips = h if skips is None else (skips + h)
        x = skips
        for f in self.last_conv_layers:
            try:
                x = f.incremental_forward(x)
            except AttributeError:
                x = f(x)

        # Generate next input by sampling
        if self.scalar_input:
            x = sample_from_discretized_mix_logistic(
                x.view(B, -1, 1), log_scale_min=log_scale_min)
        else:
            x = F.softmax(x.view(B, -1), dim=1) if softmax else x.view(B, -1)
            if quantize:
                sample = np.random.choice(
                    np.arange(self.out_channels), p=x.view(-1).data.cpu().numpy())
                x.zero_()
                x[:, sample] = 1.0
        outputs += [x.data]
        probs = F.softmax(outputs[-1], dim=-1).numpy()
#         print(f"probs at t={t}: {probs}")
        state, scalar_input = util.categoricals_pop(probs, precision)(state)
#         print("output.shape", output.shape)
        scalar_input = scalar_input[0]
        recovered_data[t] = scalar_input
        print("scalar_input: ", scalar_input)

    # T x B x C
    outputs = torch.stack(outputs)
    # B x C x T
    outputs = outputs.transpose(0, 1).transpose(1, 2).contiguous()


    self.clear_buffer()
    return recovered_data

import types
model.incremental_forward_recover = types.MethodType(incremental_forward_recover, model)

In [329]:
data[0:100]

array([189,  68,  48,  42,  41,  36,  29,  22,  18,  16,  16,  18,  23,
        30,  42,  60,  78,  83,  75,  70,  95, 193, 209, 218, 224, 224,
       221, 215, 199, 179, 203, 219, 222, 220, 216, 208, 190, 146, 169,
       199, 208, 204, 184,  73,  45,  32,  28,  29,  32,  36,  39,  38,
        33,  29,  27,  29,  37,  57, 111, 181, 189, 185, 146,  71,  72,
       188, 217, 227, 230, 229, 224, 215, 212, 217, 218, 217, 216, 212,
       199, 169,  86,  89,  83,  75,  79,  81,  71,  56,  45,  39,  37,
        41,  50,  61,  67,  68,  66,  57,  53,  63])

In [330]:
recovered_data = model.incremental_forward_recover(state, length=data.size)

patching worked!
iteration 0
scalar_input:  189
scalar_input:  68
scalar_input:  48
scalar_input:  42
scalar_input:  41
scalar_input:  36
scalar_input:  29
scalar_input:  22
scalar_input:  18
scalar_input:  16
scalar_input:  16
scalar_input:  18
scalar_input:  23
scalar_input:  30
scalar_input:  42
scalar_input:  60
scalar_input:  78
scalar_input:  83
scalar_input:  75
scalar_input:  70
scalar_input:  95
scalar_input:  193
scalar_input:  209
scalar_input:  218
scalar_input:  224
scalar_input:  224
scalar_input:  221
scalar_input:  215
scalar_input:  199
scalar_input:  179
scalar_input:  203
scalar_input:  219
scalar_input:  222
scalar_input:  220
scalar_input:  216
scalar_input:  208
scalar_input:  190
scalar_input:  146
scalar_input:  169
scalar_input:  199
scalar_input:  208
scalar_input:  204
scalar_input:  184


KeyboardInterrupt: 

In [6]:
state, recovered_data = util.categoricals_pop(probs, precision)(state)
print("state", state)
print("recovered_bits", recovered_data)

data [1 1]
state 1 (2147483648, ())
state 2 (2956104068, ())
state (2147483648, ())
recovered_bits [1 1]
