In [1]:
from model import Network
from data import load_cifar10_dataloaders
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from criterion import NonGaussianityLoss, TotalVariationLoss, MutualIndependenceLoss, ReconstructImageFromFCLoss



In [2]:
# Load the pretrained FC network and a single validation batch to attack.
model_config = './model_config/fc2_cocktail_party_instance_wout_bias.json'
checkpoint_path = './checkpoints/121923_fc2_cocktail_party_cifar10_pretraining_wout_bias.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Network(model_config)
# map_location remaps tensors saved on another device (e.g. a CUDA run) onto the
# current one; without it, loading a GPU checkpoint on a CPU-only host fails.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)

# Validation loader. These normalization constants are also undone by the
# inverse transform used for visualization further down.
normalize_mean, normalize_std = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)
batch_size = 8
data_path = './data'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std),
])

_, val_dataloader = load_cifar10_dataloaders(data_path, batch_size, transform)
# Only the first validation batch is used as the "victim" batch.
selected_val_batch_data, selected_val_batch_label = next(iter(val_dataloader))
selected_val_batch_data = selected_val_batch_data.to(device)
selected_val_batch_label = selected_val_batch_label.to(device)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Run one forward/backward pass on the selected batch and keep a copy of every
# parameter gradient; these gradients are the input to the reconstruction below.
model.zero_grad()
criterion = nn.CrossEntropyLoss()
# Flatten each image for the FC network. Use the batch's real length rather than
# the configured batch_size so a shorter batch would not break the reshape.
output = model(selected_val_batch_data.reshape(selected_val_batch_data.size(0), -1))
loss = criterion(output, selected_val_batch_label)
loss.backward()
# Independent, detached CPU copies (clone() matters when the model is already on CPU,
# where .cpu() would return the same storage as param.grad).
gradient_of_layers = [param.grad.detach().clone().cpu() for param in model.parameters()]

In [4]:
# One entry per parameter tensor captured above.
print(list(map(torch.Tensor.size, gradient_of_layers)))

[torch.Size([256, 3072]), torch.Size([10, 256])]


In [5]:
# Optimize an unmixing matrix (one row per image in the batch, one column per row of
# the first-layer weight gradient) against ReconstructImageFromFCLoss.
# NOTE(review): the meaning of the constructor arguments (32, 32, 3, 1, 1, 1) lives in
# criterion.py -- presumably image geometry followed by loss-term weights; confirm.
UNMIX_SEED = 0        # seeds the random initialization so the run is reproducible
NUM_ITERS = 10000
LOG_EVERY = 1000

# A local generator keeps the seed from touching torch's global RNG state.
unmix_generator = torch.Generator().manual_seed(UNMIX_SEED)
unmixing_matrix = torch.rand(
    (selected_val_batch_data.size(0), gradient_of_layers[0].size(0)),
    generator=unmix_generator,
    requires_grad=True,
)
reconstruction_loss = ReconstructImageFromFCLoss(32, 32, 3, 1, 1, 1)
optimizer = torch.optim.Adam([unmixing_matrix])

for iter_idx in range(NUM_ITERS):
    optimizer.zero_grad()
    loss = reconstruction_loss(unmixing_matrix, gradient_of_layers[0])
    loss.backward()
    optimizer.step()

    # Log the initial loss and then every LOG_EVERY iterations.
    if iter_idx == 0 or (iter_idx + 1) % LOG_EVERY == 0:
        print('loss: {}'.format(loss.item()))

loss: 1.8976237773895264
loss: 0.8755009174346924
loss: 0.8752057552337646
loss: 0.8751492500305176
loss: 0.8751644492149353
loss: 0.8751627206802368
loss: 0.8751421570777893
loss: 0.8751373887062073
loss: 0.8751441240310669
loss: 0.8751401305198669
loss: 0.875136137008667


In [6]:
# Apply the learned unmixing to the first-layer weight gradient: each row is one
# flattened image estimate. Detached so the result carries no autograd graph back
# to the optimized leaf (the cells that follow only inspect and visualize it).
estimated_imgs = unmixing_matrix.detach() @ gradient_of_layers[0]

In [8]:
# Reshape each flattened 3072-vector estimate into a (3, 32, 32) channel-first image,
# matching how the input batch was flattened before the forward pass.
imgs = estimated_imgs.reshape((estimated_imgs.shape[0], 3, 32, 32))

In [13]:
# Inspect the raw reconstructions. The values are on the order of 1e-5 (see the
# output), far below the normalized-pixel range, so they are scaled up before display.
imgs

tensor([[[[ 4.8460e-05,  4.9524e-05,  5.1156e-05,  ..., -2.3359e-05,
           -1.6787e-05, -1.8177e-05],
          [ 4.8267e-05,  4.8632e-05,  4.8591e-05,  ..., -1.0638e-05,
           -1.2272e-05, -1.4980e-05],
          [ 4.7672e-05,  4.9267e-05,  4.9029e-05,  ..., -2.5022e-07,
           -8.1814e-06, -1.1367e-05],
          ...,
          [-3.0382e-05, -3.1026e-05, -3.1281e-05,  ..., -1.6945e-05,
           -1.4008e-05, -1.2181e-05],
          [-3.3678e-05, -3.2953e-05, -3.2148e-05,  ..., -1.4746e-05,
           -1.3227e-05, -1.0206e-05],
          [-3.5266e-05, -3.4505e-05, -3.3083e-05,  ..., -1.5124e-05,
           -1.0981e-05, -1.6106e-05]],

         [[ 4.8727e-05,  4.9562e-05,  5.1071e-05,  ..., -2.0755e-05,
           -1.4296e-05, -1.6014e-05],
          [ 4.7368e-05,  4.8341e-05,  4.9048e-05,  ..., -8.5515e-06,
           -1.0653e-05, -1.3530e-05],
          [ 4.5459e-05,  4.8267e-05,  4.9450e-05,  ...,  1.0017e-06,
           -7.7515e-06, -1.0736e-05],
          ...,
     

In [28]:
# Inverse of transforms.Normalize(normalize_mean, normalize_std): x -> x * std + mean.
# Built from the same constants as the forward transform so the two cannot drift apart.
invTrans = transforms.Compose([
    # Step 1: (x - 0) / (1 / std) == x * std
    transforms.Normalize(mean=[0.0] * len(normalize_std),
                         std=[1 / s for s in normalize_std]),
    # Step 2: (x - (-mean)) / 1 == x + mean
    transforms.Normalize(mean=[-m for m in normalize_mean],
                         std=[1.0] * len(normalize_mean)),
])

In [29]:
# Reconstructions are ~1e-5 in magnitude, so apply a gain before de-normalizing.
# NOTE(review): 10000 is an ad-hoc display gain, not a calibrated value.
DISPLAY_SCALE = 10000
# ToPILImage turns float tensors into 8-bit pixels by scaling by 255 without saturating,
# so values outside [0, 1] would wrap around; clamp first so they saturate instead.
img = transforms.ToPILImage()(invTrans(imgs[0].detach() * DISPLAY_SCALE).clamp(0, 1))

In [30]:
# Render inline through PIL's rich repr. img.show() launches an external viewer on the
# kernel's host, which is invisible when the kernel runs on a remote GPU machine.
img

In [26]:

# Ground truth: de-normalize the first selected validation image to compare against
# the reconstruction. Note: this rebinds `img`, replacing the reconstructed image.
# Clamp guards against tiny round-off outside [0, 1] before 8-bit conversion.
img = transforms.ToPILImage()(invTrans(selected_val_batch_data[0].cpu()).clamp(0, 1))