In [1]:
import os, sys, torch
sys.path.append(os.path.abspath('../modules'))
sys.path.append(os.path.abspath('../modules/mnist'))
import vae_train as vt
import vae_ortho as vo
import vae_surgery as vs
import datapipe
import classifier as cl
from vae import VAE
import utility as ut

# Output folder for the from-scratch VAE; later cells reassign `folder`.
folder = '../data/MNIST/vae'
# NOTE(review): `epochs` is not referenced by the cells in this notebook, which
# pass their epoch counts literally — confirm whether it is still needed.
epochs = 100
batch_size = 100   # batch size used by the classifier evaluation cell
latent_dim = 2

# Prefer CUDA, then Apple MPS, then CPU, so the notebook also runs on machines
# without any accelerator.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Seed torch's global RNG so a Restart & Run All starts from the same state.
torch.manual_seed(0)

**Train a VAE from scratch**

In [2]:
# Train a VAE from scratch (no initial model). `folder` is set here explicitly
# because later cells reassign it; relying on the config cell's value would make
# a re-run of this cell write into whatever folder was set last.
folder = '../data/MNIST/vae'
vt.train(None, folder, 500, 1000, latent_dim, device, log_interval=1)

Epochs: 100%|█████████████████████████████████| 500/500 [17:02<00:00,  2.05s/it]


**Train a VAE with uniformity loss**

In [3]:
# Same positional arguments as the from-scratch run, with the uniformity term
# weighted by `uniformity_weight`; results go to their own folder.
folder = '../data/MNIST/vaeu'
vt.train_with_uniformity(
    None, folder, 500, 1000, latent_dim, device,
    log_interval=1,
    uniformity_weight=1e5,
)

Epochs: 100%|█████████████████████████████████| 500/500 [15:38<00:00,  1.88s/it]


**Retrain a VAE with $L_{\text{rest}}$ + orthogonal loss**

In [3]:
# Start from the epoch-500 checkpoint of the from-scratch VAE and retrain it
# (L_rest + orthogonal loss, per the heading) into its own folder.
net_path = '../data/MNIST/vae/checkpoints/vae_500.pth'
folder = '../data/MNIST/vae-o-rest'
orthogonality_factor = 10.

model = VAE(device=device)
# map_location remaps tensors saved on another accelerator (e.g. mps) onto the
# device this kernel selected; without it torch.load restores to the original one.
model.load_state_dict(torch.load(net_path, map_location=device))
model.to(device)

vo.train(model, folder, 500, 100, latent_dim, device, orthogonality_factor, log_interval=1)

Epochs: 100%|█████████████████████████████████| 500/500 [17:24<00:00,  2.09s/it]


**Retrain a VAE with $L_{\text{all}}$ + orthogonal loss**

In [5]:
# Start from the epoch-500 checkpoint of the from-scratch VAE and retrain it
# (L_all + orthogonal loss, per the heading); differs from the L_rest run only
# by passing one_weight=1.
net_path = '../data/MNIST/vae/checkpoints/vae_500.pth'
folder = '../data/MNIST/vae-o-all'
orthogonality_factor = 10.

model = VAE(device=device)
# map_location remaps tensors saved on another accelerator (e.g. mps) onto the
# device this kernel selected; without it torch.load restores to the original one.
model.load_state_dict(torch.load(net_path, map_location=device))
model.to(device)

vo.train(model, folder, 500, 100, latent_dim, device, orthogonality_factor, log_interval=1, one_weight=1.)

Epochs: 100%|█████████████████████████████████| 500/500 [18:56<00:00,  2.27s/it]


**Retrain a VAE with $L_{\text{all}}$ + orthogonal loss (long run, 2000 epochs)**

In [9]:
# Long variant of the L_all + orthogonal-loss retraining: 2000 epochs instead
# of 500, written to a separate folder.
net_path = '../data/MNIST/vae/checkpoints/vae_500.pth'
folder = '../data/MNIST/vae-o-all-long'
orthogonality_factor = 10.

model = VAE(device=device)
# map_location remaps tensors saved on another accelerator (e.g. mps) onto the
# device this kernel selected; without it torch.load restores to the original one.
model.load_state_dict(torch.load(net_path, map_location=device))
model.to(device)

vo.train(model, folder, 2000, 100, latent_dim, device, orthogonality_factor, log_interval=1, one_weight=1.)

Epochs: 100%|█████████████████████████████| 2000/2000 [1:09:10<00:00,  2.08s/it]


**Retrain a VAE with $L_{\text{rest}}$**

In [6]:
# Control run: same retraining as vae-o-rest but with orthogonality_factor = 0,
# i.e. L_rest only (per the heading).
net_path = '../data/MNIST/vae/checkpoints/vae_500.pth'
folder = '../data/MNIST/vae-retrain'
orthogonality_factor = 0.

model = VAE(device=device)
# map_location remaps tensors saved on another accelerator (e.g. mps) onto the
# device this kernel selected; without it torch.load restores to the original one.
model.load_state_dict(torch.load(net_path, map_location=device))
model.to(device)

vo.train(model, folder, 500, 100, latent_dim, device, orthogonality_factor, log_interval=1)

Epochs: 100%|█████████████████████████████████| 500/500 [17:08<00:00,  2.06s/it]


**Perform surgery on vae-o-rest**

In [7]:
# Run surgery (vs.operate) on a vae-o-rest checkpoint.
# NOTE(review): this loads checkpoint `vae_1.pth` although the vae-o-rest run
# above trained for 500 epochs — confirm this is the intended checkpoint.
net_path = '../data/MNIST/vae-o-rest/checkpoints/vae_1.pth'
folder = '../data/MNIST/vae-o-rest-s'

model = VAE(device=device)
# map_location remaps tensors saved on another accelerator (e.g. mps) onto the
# device this kernel selected; without it torch.load restores to the original one.
model.load_state_dict(torch.load(net_path, map_location=device))
model.to(device)

vs.operate(model, folder, 100, 100, latent_dim, device=device, log_interval=1)

Epochs: 100%|█████████████████████████████████| 100/100 [07:22<00:00,  4.42s/it]


**Perform surgery on vae-o-all**

In [8]:
# Run surgery (vs.operate) on a vae-o-all checkpoint.
# NOTE(review): this loads checkpoint `vae_1.pth` although the vae-o-all run
# above trained for 500 epochs — confirm this is the intended checkpoint.
net_path = '../data/MNIST/vae-o-all/checkpoints/vae_1.pth'
folder = '../data/MNIST/vae-o-all-s'

model = VAE(device=device)
# map_location remaps tensors saved on another accelerator (e.g. mps) onto the
# device this kernel selected; without it torch.load restores to the original one.
model.load_state_dict(torch.load(net_path, map_location=device))
model.to(device)

vs.operate(model, folder, 100, 100, latent_dim, device=device, log_interval=1)

Epochs: 100%|█████████████████████████████████| 100/100 [07:17<00:00,  4.38s/it]


**Perform surgery on vae**

In [6]:
# Run surgery (vs.operate) directly on the epoch-500 from-scratch VAE,
# for 500 epochs.
net_path = '../data/MNIST/vae/checkpoints/vae_500.pth'
folder = '../data/MNIST/vae-s'

model = VAE(device=device)
# map_location remaps tensors saved on another accelerator (e.g. mps) onto the
# device this kernel selected; without it torch.load restores to the original one.
model.load_state_dict(torch.load(net_path, map_location=device))
model.to(device)

vs.operate(model, folder, 500, 100, latent_dim, device=device, log_interval=1)

Epochs: 100%|█████████████████████████████████| 500/500 [36:56<00:00,  4.43s/it]


In [14]:
# Fraction of MNIST samples the pretrained classifier assigns to each class.
dataloader = datapipe.MNIST().get_dataloader(batch_size)
categorizer = cl.get_classifier(device=device)

class_counts = None  # running per-class prediction counts
n_seen = 0           # total number of predictions made
with torch.no_grad():  # inference only; no autograd graph needed
    for img, _ in dataloader:
        logits = categorizer(img.to(device))
        preds = logits.argmax(dim=1)  # argmax(softmax(x)) == argmax(x)
        # minlength keeps every class slot present even if a batch contains no
        # prediction for the highest-index classes (otherwise shapes mismatch).
        counts = torch.bincount(preds, minlength=logits.shape[1])
        class_counts = counts if class_counts is None else class_counts + counts
        # Count actual predictions so a short final batch is weighted correctly.
        n_seen += preds.numel()

class_freq = class_counts.float() / n_seen
print(class_freq)

tensor([0.0987, 0.1123, 0.0993, 0.1022, 0.0974, 0.0904, 0.0986, 0.1044, 0.0975,
        0.0992], device='mps:0')
