In [1]:
import os, sys, torch
sys.path.append(os.path.abspath('../modules/mnist'))
sys.path.append(os.path.abspath('../modules'))
import vae_train as vt
import vae_ortho as vo
import vae_surgery as vs
import vae 
import utility as ut

# Shared configuration for the cells below.
folder = '../data/MNIST/vae'   # default output folder; most cells below override it
epochs = 100
batch_size = 100               # passed as the batch-size argument of the train/operate helpers
latent_dim = 2                 # passed to vae.VAE(...) and to the train/operate helpers

# Pick the best available accelerator. Falling back to "mps" unconditionally would
# break on machines that have neither CUDA nor Apple-silicon MPS, so end with CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

**Retrain a LoRA_VAE_Decoder with $L_{\text{rest}}$ + orthogonal loss**

In [10]:
# Retrain LoRA adapters on the decoder of a frozen, pre-trained VAE with the
# orthogonality penalty. one_weight=0. is the setting for this L_rest run
# (the L_all run uses one_weight=1.).
lora_r = 50
net_path = "../data/MNIST/vae/checkpoints/vae_500.pth"
folder = f"../data/MNIST/vae-o-rest-ld-r{lora_r}"
orthogonality_factor = 10.

model = vae.VAE(latent_dim=latent_dim, device=device)
# map_location lets a checkpoint saved on one device (e.g. CUDA) load on another (MPS/CPU).
model.load_state_dict(torch.load(net_path, map_location=device))
lora_model = vae.LoRA_VAE_Decoder(ut.freeze_all(model), lora_r=lora_r).to(device)

# Recorded outputs suggest these are (total, trainable) parameter counts -- TODO confirm in ut.count_params.
print(ut.count_params(model))
print(ut.count_params(lora_model))
vo.train(lora_model, folder, 500, batch_size, latent_dim, device, orthogonality_factor,
         log_interval=1, one_weight=0.)

(632788, 0)
(712088, 79300)


Epochs:   4%|█▎                                | 19/500 [00:28<12:12,  1.52s/it]


KeyboardInterrupt: 

**Retrain a LoRA_VAE_Decoder with $L_{\text{all}}$ + orthogonal loss**

In [14]:
# Retrain LoRA adapters on the decoder of a frozen, pre-trained VAE with the
# orthogonality penalty. one_weight=1. is the setting for this L_all run
# (the L_rest run uses one_weight=0.).
lora_r = 10
net_path = "../data/MNIST/vae/checkpoints/vae_500.pth"
folder = f"../data/MNIST/vae-o-all-ld-r{lora_r}"
orthogonality_factor = 10.

model = vae.VAE(latent_dim=latent_dim, device=device)
# map_location lets a checkpoint saved on one device (e.g. CUDA) load on another (MPS/CPU).
model.load_state_dict(torch.load(net_path, map_location=device))
lora_model = vae.LoRA_VAE_Decoder(ut.freeze_all(model), lora_r=lora_r).to(device)

# Recorded outputs suggest these are (total, trainable) parameter counts -- TODO confirm in ut.count_params.
print(ut.count_params(model))
print(ut.count_params(lora_model))
vo.train(lora_model, folder, 500, batch_size, latent_dim, device, orthogonality_factor,
         log_interval=1, one_weight=1.)

(632788, 0)
(648648, 15860)


Epochs: 100%|█████████████████████████████████| 500/500 [12:01<00:00,  1.44s/it]


**Perform surgery on `vae-o-rest-ld-r*`**

In [None]:
# Rebuild the frozen base VAE, attach a LoRA decoder, load trained LoRA weights
# from the L_rest run, then run the surgery step (vs.operate).
lora_r = 10
net_path_0 = "../data/MNIST/vae/checkpoints/vae_500.pth"
# NOTE(review): the L_rest training cell in this notebook uses lora_r = 50 (and its
# recorded run was interrupted), and this path loads `vae_1.pth` rather than a later
# checkpoint -- confirm both are intended.
net_path_1 = f"../data/MNIST/vae-o-rest-ld-r{lora_r}/checkpoints/vae_1.pth"
folder = f"../data/MNIST/vae-o-rest-ld-r{lora_r}-s"

model = vae.VAE(device=device)
# map_location lets checkpoints saved on one device (e.g. CUDA) load on another (MPS/CPU).
model.load_state_dict(torch.load(net_path_0, map_location=device))
model.to(device)
lora_model = vae.LoRA_VAE_Decoder(ut.freeze_all(model), lora_r=lora_r).to(device)
lora_model.load_state_dict(torch.load(net_path_1, map_location=device))

vs.operate(lora_model, folder, 100, batch_size, latent_dim, device=device, log_interval=1)

**Perform surgery on `vae-o-all-ld-r*`**

In [15]:
# Rebuild the frozen base VAE, attach a LoRA decoder, load trained LoRA weights
# from the L_all run, then run the surgery step (vs.operate).
lora_r = 10
net_path_0 = "../data/MNIST/vae/checkpoints/vae_500.pth"
# NOTE(review): loads `vae_1.pth` although the L_all run trained for 500 epochs -- confirm
# an early checkpoint is intended here.
net_path_1 = f"../data/MNIST/vae-o-all-ld-r{lora_r}/checkpoints/vae_1.pth"
folder = f"../data/MNIST/vae-o-all-ld-r{lora_r}-s"

model = vae.VAE(device=device)
# map_location lets checkpoints saved on one device (e.g. CUDA) load on another (MPS/CPU).
model.load_state_dict(torch.load(net_path_0, map_location=device))
model.to(device)
lora_model = vae.LoRA_VAE_Decoder(ut.freeze_all(model), lora_r=lora_r).to(device)
lora_model.load_state_dict(torch.load(net_path_1, map_location=device))

vs.operate(lora_model, folder, 100, batch_size, latent_dim, device=device, log_interval=1)

Epochs: 100%|█████████████████████████████████| 100/100 [07:37<00:00,  4.58s/it]


In [None]:
# Draw random samples from the `vae-o-last` checkpoint.
# NOTE(review): the only cell that sets folder `vae-o-last` freezes the encoder but never
# calls a training function, so nothing in this notebook writes this checkpoint.
net_path = '../data/MNIST/vae-o-last/checkpoints/vae_500.pth'
vae.generate_random_samples(
    net_path,
    latent_dim,
    num_samples=169,
    device=device,
)

In [None]:
# Fine-tune a pre-trained VAE with the orthogonality penalty, leaving only the part
# selected by ut.freeze_all_but_first trainable.
net_path = '../data/MNIST/vae/checkpoints/vae_100.pth'
folder = '../data/MNIST/vae-o-first'
orthogonality_factor = 10.

model = vae.VAE(latent_dim = latent_dim, device=device).to(device)
# map_location lets a checkpoint saved on one device (e.g. CUDA) load on another (MPS/CPU).
model.load_state_dict(torch.load(net_path, map_location=device))
vae.train_orthogonal(ut.freeze_all_but_first(model), folder, 500, batch_size, latent_dim, device,
                     orthogonality_factor)

In [None]:
# Draw random samples from the final checkpoint of the `vae-o-first` run.
net_path = '../data/MNIST/vae-o-first/checkpoints/vae_500.pth'
vae.generate_random_samples(
    net_path,
    latent_dim,
    num_samples=169,
    device=device,
)

In [3]:
# Wrap a frozen, pre-trained VAE in LoRA adapters of rank `lora_r` and train them
# with the orthogonality penalty.
lora_r = 100
net_path = '../data/MNIST/vae/checkpoints/vae_100.pth'
folder = f'../data/MNIST/vae-o-lora-rank-{lora_r}'
orthogonality_factor = 10.

model = vae.VAE(latent_dim = latent_dim, device=device).to(device)
# map_location lets a checkpoint saved on one device (e.g. CUDA) load on another (MPS/CPU).
model.load_state_dict(torch.load(net_path, map_location=device))

lora_model = vae.LoRA_VAE(ut.freeze_all(model), lora_r=lora_r).to(device)
vae.train_orthogonal(lora_model, folder, 500, batch_size, latent_dim, device, orthogonality_factor)

Epochs: 100%|█████████████████████████████████| 500/500 [15:26<00:00,  1.85s/it]


In [2]:
# Load a pre-trained VAE, freeze its encoder, and report parameter counts.
net_path = '../data/MNIST/vae/checkpoints/vae_100.pth'
# TODO(review): `folder` and `orthogonality_factor` are set but never used -- the
# training call that would write '../data/MNIST/vae-o-last' appears to be missing.
folder = '../data/MNIST/vae-o-last'
orthogonality_factor = 10.

model = vae.VAE(latent_dim = latent_dim, device=device).to(device)
# map_location lets a checkpoint saved on one device (e.g. CUDA) load on another (MPS/CPU).
model.load_state_dict(torch.load(net_path, map_location=device))
model.freeze_encoder()
ut.count_params(model)

(632788, 317988)

In [2]:
# Rebuild the frozen base VAE, wrap it in rank-`lora_r` LoRA adapters, load the trained
# adapter weights, then run the surgery step (vae.operate).
lora_r = 100

net_path = '../data/MNIST/vae/checkpoints/vae_100.pth'
model = vae.VAE(latent_dim = latent_dim, device=device).to(device)
# map_location lets checkpoints saved on one device (e.g. CUDA) load on another (MPS/CPU).
model.load_state_dict(torch.load(net_path, map_location=device))
lora_model = vae.LoRA_VAE(ut.freeze_all(model), lora_r=lora_r).to(device)

net_path_lora = f'../data/MNIST/vae-o-lora-rank-{lora_r}/checkpoints/vae_100.pth'
lora_model.load_state_dict(torch.load(net_path_lora, map_location=device))
folder = f'../data/MNIST/vae-o-lora-rank-{lora_r}-s'
vae.operate(lora_model, folder, 100, batch_size, latent_dim, device)

Epochs: 100%|█████████████████████████████████| 100/100 [06:23<00:00,  3.83s/it]
