In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam
from ae import AE, NAE
from modules import DeConvNet2, ConvNet2FC, FC_supermask_encode, FC_supermask_decode, FC_supermask_encode_nonstochastic, FC_supermask_decode_nonstochastic, FC_original_encode, FC_original_decode
from leaveout_dataset import MNISTLeaveOut

from leaveout_dataset import MNISTLeaveOut
from tqdm import tqdm

from sklearn.metrics import roc_auc_score
from utils import roc_btw_arr
from torchvision.utils import make_grid, save_image
from torchvision.transforms import ToTensor
from itertools import chain



In [2]:
device = 1

n_ae_epoch = 1
finetune_epoch = 50
gamma = 1.
l2_norm_reg = None
l2_norm_reg_en = None #0.0001 
spherical = True 
clip_grad = None
batch_size = 128
leave_out = 1
pruning_ratio = 0.5




In [3]:
def predict(m, dl, device, flatten=False):
    l_result = []
    for x, _ in dl:
        with torch.no_grad():
            if flatten:
                x = x.view(len(x), -1)
            pred = m.predict(x.cuda(device)).detach().cpu()
        l_result.append(pred)
    return torch.cat(l_result)


'''load dataset'''
ds = MNISTLeaveOut('dataset', [leave_out], split='training', transform=ToTensor(), download=True)
in_train_dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=10)
ds = MNISTLeaveOut('dataset', [leave_out], split='validation', transform=ToTensor(), download=True)
in_val_dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=10)
ds = MNISTLeaveOut('dataset', [leave_out], split='evaluation', transform=ToTensor(), download=True)
in_test_dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=10)

in_digits = list(set(list(range(10)))-set([leave_out]))
ds = MNISTLeaveOut('dataset', in_digits, split='validation', transform=ToTensor(), download=True)
out_val_dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=10)
ds = MNISTLeaveOut('dataset', in_digits, split='evaluation', transform=ToTensor(), download=True)
out_test_dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=10)



In [4]:
'''build model for RE loss weight training'''
z_dim = 17
encoder = FC_original_encode(device)
decoder = FC_original_decode(device)

# encoder = FC_supermask_encode_nonstochastic(device, sparsity = args.pruning_ratio)#ConvNet2FC(1, z_dim, nh=8, nh_mlp=1024, out_activation='linear')
# decoder = FC_supermask_decode_nonstochastic(device, sparsity = args.pruning_ratio) #DeConvNet2(z_dim, 1, nh=8, out_activation='sigmoid')
        
model = NAE(encoder, decoder, l2_norm_reg=l2_norm_reg, l2_norm_reg_en=l2_norm_reg_en, spherical=spherical, z_step=10, z_stepsize=0.2, z_noise_std=0.05, x_step=50, x_stepsize=0.2, x_noise_std=0.05, x_noise_anneal=1., x_bound=(0, 1), z_bound=None, x_clip_langevin_grad=None)
model.cuda(device);
opt = Adam(model.parameters(), lr=0.0001)

In [5]:
print('starting autoencoder pre-training...')
n_epoch = n_ae_epoch; l_ae_result = []
i = 0
for i_epoch in range(n_epoch):
    for x, _ in in_train_dl:
        x = x.reshape(-1,784).cuda(device)
        d_result = model.train_step_ae(x, opt, clip_grad=clip_grad)

        if i % 50 == 0:
            '''val recon error'''
            val_err = predict(model, in_val_dl, device, flatten=True)
            
            in_pred = predict(model, in_test_dl, device, True)
            out_pred = predict(model, out_test_dl, device, True)
            auc = roc_btw_arr(out_pred, in_pred)
            print('{} epoch {} iterations - AUROC {}'.format(i_epoch, i+1, auc))
        i += 1

starting autoencoder pre-training...
0 epoch 1 iterations - AUROC 0.10088388977094001
0 epoch 51 iterations - AUROC 0.11120175118207275
0 epoch 101 iterations - AUROC 0.1581880930551518
0 epoch 151 iterations - AUROC 0.15235393357533833
0 epoch 201 iterations - AUROC 0.16213123430011106
0 epoch 251 iterations - AUROC 0.2054665802008095
0 epoch 301 iterations - AUROC 0.2545646767096263
0 epoch 351 iterations - AUROC 0.25127922260237384


In [6]:
model.decoder.fc1.weight

Parameter containing:
tensor([[-0.0427,  0.0254, -0.0244,  ...,  0.0006, -0.0167,  0.0366],
        [-0.0274, -0.0375, -0.0275,  ...,  0.0166, -0.0295, -0.0278],
        [ 0.0396, -0.0309,  0.0167,  ..., -0.0298,  0.0368,  0.0342],
        ...,
        [ 0.0439,  0.0182,  0.0234,  ..., -0.0309,  0.0368,  0.0140],
        [ 0.0036, -0.0125,  0.0380,  ..., -0.0372, -0.0394,  0.0342],
        [ 0.0147, -0.0545, -0.0011,  ...,  0.0270, -0.0413, -0.0039]],
       device='cuda:1', requires_grad=True)

In [7]:
in_pred = predict(model, in_test_dl, device, flatten=True)
out_pred = predict(model, out_test_dl, device, flatten=True)
auc = roc_btw_arr(out_pred, in_pred)
print(f'[step1 model][vs{leave_out} AUC]: {auc}')

[step1 model][vs1 AUC]: 0.24598234406951058


In [8]:
encoder = FC_supermask_encode_nonstochastic(device = device, sparsity = pruning_ratio,previous_model=model)
decoder = FC_supermask_decode_nonstochastic(device = device, sparsity = pruning_ratio,previous_model=model)
        
new_model = NAE(encoder, decoder, l2_norm_reg=l2_norm_reg, l2_norm_reg_en=l2_norm_reg_en, spherical=spherical, z_step=10, z_stepsize=0.2, z_noise_std=0.05, x_step=50, x_stepsize=0.2, x_noise_std=0.05, x_noise_anneal=1., x_bound=(0, 1), z_bound=None, x_clip_langevin_grad=None)

opt = Adam(new_model.parameters(), lr=0.00001)

new_model.cuda(device)

hi
hi
hi
hi
hi
hi
hi
hi


NAE(
  (encoder): FC_supermask_encode_nonstochastic(
    (fc1): MaskedLinear_nonstochastic()
    (fc2): MaskedLinear_nonstochastic()
    (fc3): MaskedLinear_nonstochastic()
    (fc4): MaskedLinear_nonstochastic()
  )
  (decoder): FC_supermask_decode_nonstochastic(
    (fc4): MaskedLinear_nonstochastic()
    (fc3): MaskedLinear_nonstochastic()
    (fc2): MaskedLinear_nonstochastic()
    (fc1): MaskedLinear_nonstochastic()
  )
)

In [9]:
new_model.decoder.fc1.fcw

Parameter containing:
tensor([[-0.0427,  0.0254, -0.0244,  ...,  0.0006, -0.0167,  0.0366],
        [-0.0274, -0.0375, -0.0275,  ...,  0.0166, -0.0295, -0.0278],
        [ 0.0396, -0.0309,  0.0167,  ..., -0.0298,  0.0368,  0.0342],
        ...,
        [ 0.0439,  0.0182,  0.0234,  ..., -0.0309,  0.0368,  0.0140],
        [ 0.0036, -0.0125,  0.0380,  ..., -0.0372, -0.0394,  0.0342],
        [ 0.0147, -0.0545, -0.0011,  ...,  0.0270, -0.0413, -0.0039]],
       device='cuda:1')

In [10]:
in_pred = predict(new_model, in_test_dl, device, flatten=True)
out_pred = predict(new_model, out_test_dl, device, flatten=True)
auc = roc_btw_arr(out_pred, in_pred)
print(f'[Transfered Model][vs{leave_out} AUC]: {auc}')



[Transfered Model][vs1 AUC]: 0.24598234406951058
