In [None]:
from model import Network
from data import load_mnist_dataloaders, WhiteningTransformation
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from criterion import ReconstructImageFromFCLoss

In [2]:
# Load the pretrained FC model (config + checkpoint) onto `device`.
model_config = './model_config/fc1_cocktail_party_mnist_instance.json'
checkpoint_path = './checkpoints/122123_fc1_cocktail_party_mnist_pretraining_wout_bias.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Network(model_config)
# map_location makes a checkpoint saved on GPU loadable on a CPU-only machine,
# matching the CPU fallback chosen for `device` above.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)

# Build the validation loader. 0.1307 / 0.3081 are the commonly used MNIST
# mean / std for normalization.
normalize_mean, normalize_std = (0.1307,), (0.3081,)
batch_size = 4
data_path = './data'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std),
])

# The first returned loader (presumably the train loader) is unused here.
_, val_dataloader = load_mnist_dataloaders(data_path, batch_size, transform)
# Use only the first validation batch: this is the batch whose gradient is
# reconstructed below. NOTE(review): if the loader shuffles, this batch
# differs between runs — confirm in load_mnist_dataloaders.
selected_val_batch_data, selected_val_batch_label = next(iter(val_dataloader))
selected_val_batch_data = selected_val_batch_data.to(device)
selected_val_batch_label = selected_val_batch_label.to(device)

In [3]:
# Run one CrossEntropy forward/backward pass on the selected batch and snapshot
# every parameter gradient to CPU. These gradients are the observed signal the
# reconstruction works from.
model.zero_grad()
criterion = nn.CrossEntropyLoss()
# Flatten each image into a vector for the FC layer. Use the real batch size
# rather than the `batch_size` config constant so a short batch still works.
num_samples = selected_val_batch_data.size(0)
output = model(selected_val_batch_data.reshape(num_samples, -1))
classification_loss = criterion(output, selected_val_batch_label)
classification_loss.backward()
# detach() replaces the discouraged `.data` access. Parameters that received
# no gradient (grad is None) are skipped instead of raising AttributeError.
gradient_of_layers = [
    param.grad.detach().clone().cpu()
    for param in model.parameters()
    if param.grad is not None
]
print([x.size() for x in gradient_of_layers])

[torch.Size([10, 784])]


In [13]:
# Whiten the first-layer gradient. WhiteningTransformation.transform is given
# the transposed numpy matrix, and its result is transposed back.
# NOTE(review): presumably transform() expects samples as rows — confirm in data.py.
whitening_transform = WhiteningTransformation()
fc_gradient_np = gradient_of_layers[0].detach().numpy()
whitened_np = whitening_transform.transform(fc_gradient_np.T)
# Cast to float32 so later torch ops see a single-precision tensor.
whitened_gradient = torch.from_numpy(whitened_np).to(torch.float32).T

In [14]:
# Fit an unmixing matrix with Adam so that ReconstructImageFromFCLoss, applied to
# (unmixing_matrix, whitened_gradient), is minimized.
# ReconstructImageFromFCLoss is already imported in the first cell.
RECONSTRUCTION_SEED = 0   # seed for the random initial unmixing matrix
NUM_ITERATIONS = 25000    # Adam steps
LOG_EVERY = 1000          # print the objective every LOG_EVERY steps (and at step 0)

torch.manual_seed(RECONSTRUCTION_SEED)
# Shape: (images in the batch, rows of the first-layer gradient).
unmixing_matrix = torch.rand(
    (selected_val_batch_data.size(0), gradient_of_layers[0].size(0)),
    requires_grad=True,
)
# NOTE(review): the arguments (28, 28, 1, 1, 1, 1) look like image height/width
# followed by loss-term weights — TODO confirm against criterion.py.
reconstruction_loss = ReconstructImageFromFCLoss(28, 28, 1, 1, 1, 1)
optimizer = torch.optim.Adam([unmixing_matrix])

for iter_idx in range(NUM_ITERATIONS):
    optimizer.zero_grad()
    # Returns (out_score, non_gaussianity_score, total_variance_score,
    # mutual_independence_score). Only out_score is back-propagated.
    total_loss, _, _, _ = reconstruction_loss(unmixing_matrix, whitened_gradient)
    total_loss.backward()
    optimizer.step()

    if iter_idx == 0 or (iter_idx + 1) % LOG_EVERY == 0:
        print('loss: {}'.format(total_loss.item()))

loss: 1.9434760808944702
loss: 0.9226415753364563
loss: 0.9137480854988098
loss: 0.9062380790710449
loss: 0.9000566601753235
loss: 0.8952429294586182
loss: 0.8913980722427368
loss: 0.8887266516685486
loss: 0.8867244720458984
loss: 0.8849146962165833
loss: 0.8839080333709717
loss: 0.8831490874290466
loss: 0.8826785087585449
loss: 0.8818979263305664
loss: 0.881306529045105
loss: 0.8813421726226807
loss: 0.8807418346405029
loss: 0.8805824518203735
loss: 0.8811609148979187
loss: 0.8812450766563416
loss: 0.8800931572914124
loss: 0.8802552223205566
loss: 0.8800516724586487
loss: 0.8802169561386108
loss: 0.8801079392433167
loss: 0.8804312348365784
