In [1]:
from model import Network
from data import load_cifar10_dataloaders, whitening_transformation
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from criterion import ReconstructImageFromFCLoss



In [2]:
# Load the pretrained network from its JSON config and checkpoint.
model_config = './model_config/fc2_cocktail_party_instance_wout_bias.json'
checkpoint_path = './checkpoints/121923_fc2_cocktail_party_cifar10_pretraining_wout_bias.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Network(model_config)
# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint written from a CUDA run can still be loaded on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)

# Build the validation loader and pull a single batch from it.
# The mean/std below are applied per channel by transforms.Normalize.
normalize_mean, normalize_std = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)
batch_size = 8
data_path = './data'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std),
])

# Only the second returned loader is used here (named val_dataloader).
_, val_dataloader = load_cifar10_dataloaders(data_path, batch_size, transform)
# First batch yielded by the loader; both tensors are moved to the model's device.
selected_val_batch_data, selected_val_batch_label = next(iter(val_dataloader))
selected_val_batch_data = selected_val_batch_data.to(device)
selected_val_batch_label = selected_val_batch_label.to(device)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Run one forward/backward pass on the selected batch and keep a CPU copy of
# the gradient of every model parameter.
model.zero_grad()
criterion = nn.CrossEntropyLoss()
# Each sample is flattened to a vector before being fed to the model.
flattened_batch = selected_val_batch_data.reshape(batch_size, -1)
output = model(flattened_batch)
loss = criterion(output, selected_val_batch_label)
loss.backward()
# Clone so later zero_grad()/backward() calls cannot alter the captured values.
gradient_of_layers = [
    parameter.grad.data.clone().to('cpu')
    for parameter in model.parameters()
]
print([layer_gradient.size() for layer_gradient in gradient_of_layers])

[torch.Size([256, 3072]), torch.Size([10, 256])]


In [4]:
# Whiten every captured gradient. whitening_transformation returns a pair per
# gradient; the first element is collected as the PCA-whitened result and the
# second as the ZCA-whitened result.
whitened_pairs = [
    whitening_transformation(layer_gradient)
    for layer_gradient in gradient_of_layers
]
pca_whitened_gradients = [pca_result for pca_result, _ in whitened_pairs]
zca_whitened_gradients = [zca_result for _, zca_result in whitened_pairs]

In [7]:
# Optimise an unmixing matrix so that (unmixing_matrix @ whitened gradient)
# minimises the reconstruction criterion for the first layer's gradient.
init_seed = 0            # fixes the random initialisation so reruns print the same losses
num_iterations = 25000   # number of Adam steps
log_every = 1000         # print the loss on the first step and every log_every steps

# A dedicated generator makes the initialisation reproducible without touching
# the global torch RNG state used elsewhere in the notebook.
init_generator = torch.Generator().manual_seed(init_seed)
# One row per sample in the batch, one column per row of the first-layer gradient.
unmixing_matrix = torch.rand(
    (selected_val_batch_data.size(0), gradient_of_layers[0].size(0)),
    generator=init_generator,
    requires_grad=True,
)
# NOTE(review): the arguments are presumably image height, width, channels and
# three loss-term weights -- confirm against criterion.ReconstructImageFromFCLoss.
reconstruction_loss = ReconstructImageFromFCLoss(32, 32, 3, 1, 1, 1)
optimizer = torch.optim.Adam([unmixing_matrix])

for iter_idx in range(num_iterations):
    optimizer.zero_grad()
    # out_score, non_gaussianity_score, total_variance_score, mutual_independence_score
    loss, _, _, _ = reconstruction_loss(unmixing_matrix, zca_whitened_gradients[0])
    loss.backward()
    optimizer.step()

    if (iter_idx + 1) % log_every == 0 or iter_idx == 0:
        print('loss: {}'.format(loss.item()))

loss: 2.0979433059692383
loss: 0.853891134262085
loss: 0.8182715773582458
loss: 0.8152432441711426
loss: 0.8132375478744507
loss: 0.8115373253822327
loss: 0.8100262880325317
loss: 0.8089626431465149
loss: 0.8082752227783203
loss: 0.8077040314674377
loss: 0.8072206974029541
loss: 0.8068822622299194
loss: 0.8065118789672852
loss: 0.8063374161720276
loss: 0.8061027526855469
loss: 0.8059698343276978
loss: 0.8058511018753052
loss: 0.8057987093925476
loss: 0.8056550025939941
loss: 0.8055974841117859
loss: 0.8055146336555481
loss: 0.8054519891738892
loss: 0.8054258823394775
loss: 0.8053689002990723
loss: 0.8053085803985596
loss: 0.8052671551704407


In [15]:
# Show one recovered image next to the input image at the same batch index.
image_idx = 5  # batch index inspected below


def _rescale_to_unit_range(tensor):
    """Min-max rescale ``tensor`` into [0, 1].

    ToPILImage multiplies float tensors by 255 and casts to uint8, so values
    outside [0, 1] do not map to valid pixel intensities.
    """
    shifted = tensor - tensor.min()
    # clamp_min avoids a division by zero when the tensor is constant.
    return shifted / shifted.max().clamp_min(1e-12)


with torch.no_grad():
    estimated_img = unmixing_matrix @ zca_whitened_gradients[0]

    # The recovered rows have an arbitrary scale (on the order of 1e-4 in this
    # run), so a fixed clamp/multiply cannot place them in a displayable range;
    # rescale by the image's own min/max instead. The sign flip is kept as the
    # cell had it before.
    # NOTE(review): row image_idx of the estimate is not guaranteed to be the
    # reconstruction of input image_idx -- confirm the row/sample correspondence.
    recovered = -1 * estimated_img[image_idx].reshape(3, 32, 32)
    img = transforms.ToPILImage()(_rescale_to_unit_range(recovered))
    img.show()

    # The batch was standardised with transforms.Normalize(normalize_mean,
    # normalize_std); invert that so the reference is back in [0, 1].
    channel_mean = torch.tensor(normalize_mean).view(3, 1, 1)
    channel_std = torch.tensor(normalize_std).view(3, 1, 1)
    reference = selected_val_batch_data[image_idx].reshape(3, 32, 32).cpu()
    reference = reference * channel_std + channel_mean
    img = transforms.ToPILImage()(torch.clamp(reference, min=0, max=1))
    img.show()

In [15]:
# Check whether any value of the first estimate saturates at the clamp bounds:
# prints the shape of the elements equal to -1, then of those equal to 1.
with torch.no_grad():
    first_estimate = estimated_img[0].reshape(3, 32, 32)
    clamped_img = first_estimate.clamp(min=-1, max=1)
    for bound in (-1, 1):
        saturated = clamped_img[clamped_img == bound]
        print(saturated.shape)

torch.Size([0])
torch.Size([0])


In [11]:
# Display the raw (unclamped) values of the first estimate as a 3x32x32 tensor.
torch.reshape(estimated_img[0], (3, 32, 32))

tensor([[[ 3.2783e-04,  2.9368e-04,  3.0880e-04,  ...,  4.9177e-04,
           5.1527e-04,  5.2216e-04],
         [ 2.7950e-04,  2.5961e-04,  2.6367e-04,  ...,  3.4911e-04,
           3.9457e-04,  3.9479e-04],
         [ 1.7792e-04,  1.3845e-04,  1.4941e-04,  ...,  3.0867e-04,
           3.8751e-04,  4.5424e-04],
         ...,
         [-4.8327e-05, -1.1369e-05,  5.7595e-05,  ..., -9.0523e-05,
          -4.6151e-05, -5.3062e-05],
         [-2.3185e-04, -1.5868e-04, -4.3910e-05,  ..., -1.1516e-04,
          -1.0162e-04, -5.3834e-05],
         [-2.3476e-04, -2.0716e-04, -1.7674e-04,  ..., -1.2411e-04,
          -1.1664e-04, -2.0626e-05]],

        [[ 2.6655e-04,  2.6097e-04,  2.8234e-04,  ...,  5.3675e-04,
           5.4891e-04,  5.5572e-04],
         [ 2.3377e-04,  2.4697e-04,  2.6934e-04,  ...,  4.1941e-04,
           4.3809e-04,  4.2729e-04],
         [ 1.4929e-04,  1.5777e-04,  1.9760e-04,  ...,  3.4826e-04,
           4.1596e-04,  4.6848e-04],
         ...,
         [-5.0159e-05, -2