In [1]:
from model import Network
from data import load_mnist_dataloaders, whitening_transformation
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from criterion import ReconstructImageFromFCLoss



In [2]:
# load to model
model_config = './model_config/fc1_cocktail_party_mnist_instance.json'
checkpoint_path = './checkpoints/122123_fc1_cocktail_party_mnist_pretraining_wout_bias.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Network(model_config)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)

# get val loader
normalize_mean, normalize_std = (0.1307,), (0.3081,)
batch_size = 4
data_path = './data'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std),
])

_, val_dataloader = load_mnist_dataloaders(data_path, batch_size, transform)
selected_val_batch_data, selected_val_batch_label = next(iter(val_dataloader))
selected_val_batch_data = selected_val_batch_data.to(device)
selected_val_batch_label = selected_val_batch_label.to(device)

In [3]:
# receiving gradients
model.zero_grad()
criterion = nn.CrossEntropyLoss()
output = model(selected_val_batch_data.reshape(batch_size, -1))
loss = criterion(output, selected_val_batch_label)
loss.backward()
gradient_of_layers = []
for param in model.parameters():
    gradient_of_layers.append(param.grad.data.clone().to('cpu'))
print([x.size() for x in gradient_of_layers])

[torch.Size([10, 784])]


In [4]:
pca_whitened_gradients = []
zca_whitened_gradients = []
for gradient in gradient_of_layers:
    pca_whitened_gradient, zca_whitened_gradient = whitening_transformation(gradient)
    pca_whitened_gradients.append(pca_whitened_gradient)
    zca_whitened_gradients.append(zca_whitened_gradient)

In [8]:
pca_whitened_gradients[0].real

tensor([[ 1.7965e-03,  4.8222e-04, -9.1254e-05,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-8.8423e-03, -6.6276e-02, -2.4433e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 9.5630e-03,  6.7600e-02,  7.8493e-01,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [-9.4683e-02, -2.1576e+00,  8.3856e-02,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.0718e-01,  7.6387e-02,  1.5459e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 9.1241e-02,  2.0765e+00, -7.7356e-02,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]])

In [9]:
# criterion output testing
unmixing_matrix = torch.rand((selected_val_batch_data.size(0), gradient_of_layers[0].size(0)), requires_grad=True)
reconstruction_loss = ReconstructImageFromFCLoss(28, 28, 1, 1, 1, 1)
optimizer = torch.optim.Adam([unmixing_matrix])

for iter_idx in range(25000):
    optimizer.zero_grad()
    # out_score, non_gaussianity_score, total_variance_score, mutual_independence_score
    loss, _, _, _ = reconstruction_loss(unmixing_matrix, pca_whitened_gradients[0].real)
    loss.backward()
    optimizer.step()
    
    if (iter_idx + 1) % 1000 == 0 or iter_idx == 0:
        print('loss: {}'.format(loss.item()))

loss: 2.221757650375366
loss: 1.298161506652832
loss: 1.2820384502410889
loss: 1.267505407333374
loss: 1.2630513906478882
loss: 1.2488138675689697
loss: 1.2305573225021362
loss: 1.1980507373809814
loss: 1.1889417171478271
loss: 1.1802971363067627
loss: 1.1714259386062622
loss: 1.1722252368927002
loss: 1.1698486804962158
loss: 1.1612379550933838
loss: 1.160032868385315
loss: 1.160524606704712
loss: 1.1516461372375488
loss: 1.1472995281219482
loss: 1.1492469310760498
loss: 1.1450276374816895
loss: 1.1400073766708374
loss: 1.1437129974365234
loss: 1.1454581022262573
loss: 1.1385128498077393
loss: 1.1335762739181519
loss: 1.1267268657684326


In [13]:
with torch.no_grad():
    estimated_img = unmixing_matrix @ pca_whitened_gradients[0].real
    img = transforms.ToPILImage()(torch.clamp(estimated_img[3].reshape(1, 28, 28), min=-1, max=1))
    img.show()
    img = transforms.ToPILImage()(torch.clamp(selected_val_batch_data[3].reshape(1, 28, 28), min=-1, max=1))
    img.show()
    # print(estimated_img[0])

In [15]:
with torch.no_grad():
    clamped_img = torch.clamp(estimated_img[0].reshape(3, 32, 32), min=-1, max=1)
    print(clamped_img[clamped_img == -1].shape)
    print(clamped_img[clamped_img == 1].shape)

torch.Size([0])
torch.Size([0])


In [11]:
estimated_img[0].reshape(3, 32, 32)

tensor([[[ 3.2783e-04,  2.9368e-04,  3.0880e-04,  ...,  4.9177e-04,
           5.1527e-04,  5.2216e-04],
         [ 2.7950e-04,  2.5961e-04,  2.6367e-04,  ...,  3.4911e-04,
           3.9457e-04,  3.9479e-04],
         [ 1.7792e-04,  1.3845e-04,  1.4941e-04,  ...,  3.0867e-04,
           3.8751e-04,  4.5424e-04],
         ...,
         [-4.8327e-05, -1.1369e-05,  5.7595e-05,  ..., -9.0523e-05,
          -4.6151e-05, -5.3062e-05],
         [-2.3185e-04, -1.5868e-04, -4.3910e-05,  ..., -1.1516e-04,
          -1.0162e-04, -5.3834e-05],
         [-2.3476e-04, -2.0716e-04, -1.7674e-04,  ..., -1.2411e-04,
          -1.1664e-04, -2.0626e-05]],

        [[ 2.6655e-04,  2.6097e-04,  2.8234e-04,  ...,  5.3675e-04,
           5.4891e-04,  5.5572e-04],
         [ 2.3377e-04,  2.4697e-04,  2.6934e-04,  ...,  4.1941e-04,
           4.3809e-04,  4.2729e-04],
         [ 1.4929e-04,  1.5777e-04,  1.9760e-04,  ...,  3.4826e-04,
           4.1596e-04,  4.6848e-04],
         ...,
         [-5.0159e-05, -2

In [19]:
import numpy as np

from scipy import stats

rng = np.random.default_rng()

n = 3

A = rng.random(size=(n, n))

cov_array = A @ A.T  # make matrix symmetric positive definite

precision = np.linalg.inv(cov_array)

cov_object = stats.Covariance.from_precision(precision)

x = rng.multivariate_normal(np.zeros(n), cov_array, size=(10000))

x_ = cov_object.whiten(x)

np.cov(x_, rowvar=False)

AttributeError: module 'scipy.stats' has no attribute 'Covariance'

In [18]:
P

array([[0.06502751, 0.41818084, 0.34031409, 0.68974631, 0.583228  ],
       [0.35283933, 0.70921727, 0.67365035, 0.47174336, 0.56917156],
       [0.9907613 , 0.15305781, 0.23506229, 0.75293318, 0.47001596],
       [0.49764115, 0.64438953, 0.818335  , 0.44351651, 0.05026389],
       [0.58473874, 0.43702655, 0.69170254, 0.00476943, 0.47568853]])