In [3]:
from model import Network
from data import load_cifar10_dataloaders, whitening_transformation
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from criterion import ReconstructImageFromFCLoss



In [4]:
# Build the network described by the JSON config and restore its saved weights.
model_config = './model_config/fc2_cocktail_party_instance_wout_bias.json'
checkpoint_path = './checkpoints/121923_fc2_cocktail_party_cifar10_pretraining_wout_bias.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Network(model_config)
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# written from a GPU run can still be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)

# Build the validation loader and pull one fixed batch from it.
# Per-channel mean/std handed to transforms.Normalize
# (presumably the CIFAR-10 statistics; confirm against the training setup).
normalize_mean, normalize_std = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)
batch_size = 8
data_path = './data'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std),
])

# The first loader returned by the helper is not needed here; only the second
# one (validation) is kept.
_, val_dataloader = load_cifar10_dataloaders(data_path, batch_size, transform)
# Take the first batch the loader yields and move it next to the model.
selected_val_batch_data, selected_val_batch_label = next(iter(val_dataloader))
selected_val_batch_data = selected_val_batch_data.to(device)
selected_val_batch_label = selected_val_batch_label.to(device)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# receiving gradients
# Run one forward/backward pass on the selected batch and keep a CPU copy of
# every parameter gradient, in model.parameters() order.
model.zero_grad()
criterion = nn.CrossEntropyLoss()
# Flatten each sample to a vector. The leading dimension is read from the batch
# itself rather than from the configured batch_size, so a short batch still works.
flattened_batch = selected_val_batch_data.reshape(selected_val_batch_data.size(0), -1)
output = model(flattened_batch)
loss = criterion(output, selected_val_batch_label)
loss.backward()
# detach() is the supported way to drop the autograd link (instead of the legacy
# `.data` attribute); clone() makes the copy independent of later zero_grad calls.
gradient_of_layers = [
    param.grad.detach().clone().to('cpu')
    for param in model.parameters()
]
print([x.size() for x in gradient_of_layers])

[torch.Size([256, 3072]), torch.Size([10, 256])]


In [None]:
# Whiten every captured gradient once; the helper returns a (PCA, ZCA) pair per
# gradient, which is split into two lists that stay aligned with gradient_of_layers.
whitened_pairs = [
    whitening_transformation(layer_gradient)
    for layer_gradient in gradient_of_layers
]
pca_whitened_gradients = [pca_part for pca_part, _zca_part in whitened_pairs]
zca_whitened_gradients = [zca_part for _pca_part, zca_part in whitened_pairs]

In [6]:
# criterion output testing
# Optimise an unmixing matrix against the PCA-whitened gradient of the first layer.
NUM_ITERATIONS = 25000   # total optimizer steps
LOG_EVERY = 1000         # print the loss every this many steps (and on step 0)
UNMIXING_SEED = 0        # fixes the random initialisation so re-runs are comparable

# A dedicated generator keeps the initialisation reproducible without touching
# the global RNG state used elsewhere in the notebook.
init_generator = torch.Generator().manual_seed(UNMIXING_SEED)
# One row per sample in the batch, one column per row of the first-layer gradient.
unmixing_matrix = torch.rand(
    (selected_val_batch_data.size(0), gradient_of_layers[0].size(0)),
    generator=init_generator,
    requires_grad=True,
)
# NOTE(review): arguments are presumably image height, width, channels followed by
# three loss-term weights; confirm against criterion.ReconstructImageFromFCLoss.
reconstruction_loss = ReconstructImageFromFCLoss(32, 32, 3, 1, 1, 1)
optimizer = torch.optim.Adam([unmixing_matrix])

for iter_idx in range(NUM_ITERATIONS):
    optimizer.zero_grad()
    # out_score, non_gaussianity_score, total_variance_score, mutual_independence_score
    criterion_outputs = reconstruction_loss(unmixing_matrix, pca_whitened_gradients[0])
    # Only the first (combined) score is optimised. It gets its own name so the
    # classification `loss` from the gradient cell is not overwritten, and IPython's
    # `_` is not shadowed by tuple unpacking.
    reconstruction_loss_value = criterion_outputs[0]
    reconstruction_loss_value.backward()
    optimizer.step()

    if (iter_idx + 1) % LOG_EVERY == 0 or iter_idx == 0:
        print('loss: {}'.format(reconstruction_loss_value.item()))

loss: 1.9379727840423584
loss: 0.9412828087806702
loss: 0.9356341361999512
loss: 0.932887077331543
loss: 0.9315226078033447
loss: 0.9314039349555969
loss: 0.9289105534553528
loss: 0.9283108711242676
loss: 0.9250630140304565
loss: 0.9226387143135071
loss: 0.9224544167518616
loss: 0.9200116991996765
loss: 0.9200348854064941
loss: 0.9200003743171692
loss: 0.9200664758682251
loss: 0.9199835062026978
loss: 0.9199564456939697
loss: 0.9200427532196045
loss: 0.9188419580459595
loss: 0.9160979390144348
loss: 0.9162095785140991
loss: 0.9161527752876282
loss: 0.9160701036453247
loss: 0.9161638021469116
loss: 0.9161427617073059
loss: 0.9149090051651001


In [18]:
# Show one reconstructed sample followed by the matching ground-truth image.
sample_idx = 4  # which sample of the batch to visualise

with torch.no_grad():
    # The batch was loaded through transforms.Normalize, i.e. (x - mean) / std per
    # channel. ToPILImage scales float tensors by 255 and casts to uint8, so it
    # needs values in [0, 1]; clamping normalized data to [-1, 1] leaves negative
    # values that are corrupted by the cast. Undo the normalization first.
    channel_mean = torch.tensor(normalize_mean).view(3, 1, 1)
    channel_std = torch.tensor(normalize_std).view(3, 1, 1)
    to_pil = transforms.ToPILImage()

    # NOTE(review): the unmixing matrix was optimised against
    # pca_whitened_gradients[0] but is applied to zca_whitened_gradients[0] here;
    # confirm that this mismatch is intentional.
    estimated_img = unmixing_matrix @ zca_whitened_gradients[0]

    # NOTE(review): the estimate is treated as if it were on the same normalized
    # scale as the input; if the recovered scale or sign differs, an extra
    # rescale may be needed before display.
    estimated_pixels = estimated_img[sample_idx].reshape(3, 32, 32).cpu() * channel_std + channel_mean
    img = to_pil(torch.clamp(estimated_pixels, min=0, max=1))
    img.show()

    # Ground truth for the same sample, moved to the CPU to match the statistics tensors.
    original_pixels = selected_val_batch_data[sample_idx].reshape(3, 32, 32).cpu() * channel_std + channel_mean
    img = to_pil(torch.clamp(original_pixels, min=0, max=1))
    img.show()

In [15]:
with torch.no_grad():
    # Clamp the first estimate to [-1, 1], then report how many of its values
    # sit exactly on the lower and on the upper bound.
    first_estimate = estimated_img[0].reshape(3, 32, 32)
    clamped_img = torch.clamp(first_estimate, min=-1, max=1)
    for bound in (-1, 1):
        at_bound = clamped_img == bound
        print(clamped_img[at_bound].shape)

torch.Size([0])
torch.Size([0])
