In [33]:
# Automatically reload edited modules (e.g. the mldec package) before each cell runs,
# so changes to library code are picked up without restarting the kernel.

%load_ext autoreload
%autoreload 2

In [34]:
import torch
import numpy as np
import copy

from mldec.utils import evaluation, training
from mldec.models import initialize

from mldec.datasets import toy_problem_data

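#### Small test: verify that training on virtual batches (unique samples weighted by their empirical frequency) produces the same loss trajectory as training on the equivalent real batches (each sample repeated according to its count).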

In [104]:
n = 8
config = {
    'device': 'cpu',
    'opt': 'adam',
    "model": "cnn",
    "conv_channels": 4, 
    "kernel_size": 3, 
    "n_layers": 3, 
    "dropout": 0.00, # Important: Dropout needs to be zero if we want models to do predictions consistently enough to compare
    "input_dim": n-1,
    "output_dim": n,
    "batch_size": 32,
    "lr": 0.003,
    "max_epochs": 5,
    "patience": 222,
    "mode": 'train',
    "n_train": 512,
}

# Seed the global numpy/torch RNGs so Restart & Run All reproduces the same batches.
# NOTE(review): this only helps if the dataset helpers draw from these global RNGs — confirm.
DATA_SEED = 0
np.random.seed(DATA_SEED)
torch.manual_seed(DATA_SEED)

EXP = 1
if EXP == 1:
    # testing for toy problem data.
    dataset_config = {
        'p': 0.1,
        'alpha': 0.7,
        'pcm': toy_problem_data.repetition_pcm(n),
        "sos_eos": config.get("sos_eos", None),
    }
    dataset_module = toy_problem_data
else:
    raise ValueError(f"No dataset configured for EXP={EXP}")

device = torch.device(config.get('device'))
# The model/optimizer are built later, in the training cell, under a fixed seed so that
# the virtual and real runs start from identical weights.

_, _, train_weights = dataset_module.create_dataset_training(n, dataset_config)
X, Y, val_weights = dataset_module.create_dataset_training(n, dataset_config)

train_weights_np = train_weights.numpy()
# true distribution of bitstrings. torch.as_tensor (not torch.tensor) because val_weights
# is already a tensor here; torch.tensor() on a tensor emits a copy-construct UserWarning.
val_weights = torch.as_tensor(val_weights, dtype=torch.float32)
X, Y, val_weights = X.to(device), Y.to(device), val_weights.to(device)


In [105]:
batch_size = config.get('batch_size')
n_batches = config.get('n_train') // batch_size
# One (Xb, Yb, weightsb) tuple per batch: samples plus their empirical weights within the batch.
batched_virtual_data = []
# One (real_X, real_Y) tuple per batch: row i of (X, Y) repeated num_each_data[i] times.
batched_real_data = []
# Running sum of per-batch histograms; divided by n_batches below to give the
# histogram of the whole training set.
downsampled_train_weights = np.zeros_like(train_weights_np)

for _ in range(n_batches):
    Xb, Yb, train_weightsb, histb = dataset_module.sample_virtual_XY(train_weights_np, batch_size, n, dataset_config)
    # Round before casting: histb * batch_size is a float product, and a bare astype(int)
    # truncates values such as 2.9999999 down to 2, silently shrinking the real batch.
    num_each_data = np.rint(histb * batch_size).astype(int)
    # Assumes histb is the normalised histogram of this batch (sums to 1); the real batch
    # must then hold exactly batch_size rows for its mean loss to match the weighted loss.
    assert num_each_data.sum() == batch_size, (
        f"real batch has {num_each_data.sum()} rows, expected {batch_size}"
    )
    Xb, Yb, train_weightsb = Xb.to(device), Yb.to(device), train_weightsb.to(device)
    downsampled_train_weights += histb
    batched_virtual_data.append((Xb, Yb, train_weightsb))

    # Build the real batch from histb: index i repeated num_each_data[i] times, in index
    # order, gathered in one shot (same rows and order as appending X[i] count times).
    row_idx = torch.as_tensor(
        np.repeat(np.arange(len(num_each_data)), num_each_data), device=X.device
    )
    real_X = X[row_idx].to(device)
    real_Y = Y[row_idx].to(device)
    batched_real_data.append((real_X, real_Y))

downsampled_train_weights /= n_batches # histogram of the training set


In [106]:
criterion_virtual = evaluation.WeightedSequenceLoss(torch.nn.BCEWithLogitsLoss)
criterion_real = torch.nn.BCEWithLogitsLoss()
INIT_SEED = 111
n_epochs = config.get('max_epochs')


def fresh_model_and_optimizer(seed=INIT_SEED):
    """Re-seed torch, then build a new model wrapper and its optimizer/scheduler.

    Calling this twice with the same seed gives both training runs identical initial
    weights, which is what makes their loss sequences comparable.
    """
    torch.manual_seed(seed)
    wrapper = initialize.initialize_model(config)
    opt, sched = training.initialize_optimizer(config, wrapper.model.parameters())
    return wrapper, opt, sched


# Run 1: gradient descent on the virtual (weighted) batches, recording every step's loss.
virtual_losses = []
model_wrapper, optimizer, scheduler = fresh_model_and_optimizer()
for _ in range(n_epochs):
    for Xb, Yb, weightsb in batched_virtual_data:
        step_loss = model_wrapper.training_step(Xb, Yb, weightsb, optimizer, criterion_virtual)
        virtual_losses.append(step_loss.item())

# Run 2: regenerate an identical model and train on the equivalent real batches
# (repeated samples). The next cell asserts both runs produce the same loss sequence.
real_losses = []
model_wrapper, optimizer, scheduler = fresh_model_and_optimizer()
for _ in range(n_epochs):
    for Xb, Yb in batched_real_data:
        step_loss = model_wrapper.real_training_step(Xb, Yb, optimizer, criterion_real)
        real_losses.append(step_loss.item())


In [107]:
# The virtual and real runs must yield the same per-step losses (same length, values within 1e-5).
np.testing.assert_allclose(
    actual=virtual_losses,
    desired=real_losses,
    atol=1e-5,
    rtol=1e-5,
)