In [1]:
#!/scratch/sagar/Projects/radio_map_deep_prior/venv/bin/python3.6

from collections import OrderedDict, namedtuple
from itertools import product
import os
from tqdm import tqdm, trange
from IPython.display import clear_output
import time
import torch.nn as nn
import pandas as pd
import torch
from models.aae import AAE
import matplotlib.pyplot as plt
from dataset import MNISTDataset, MNIST_mean, MNIST_std
from utils.plot_utils import show_latent
from utils.run_manager import RunBuilder


if torch.cuda.is_available():
    devices = ['cuda']
else:
    devices = ['cpu']
print('starting')

params = OrderedDict(
    d_lr = [0.001],
    ae_lr = [0.001],
    batch_size = [32],
    device = devices,
    shuffle = [True],
    num_workers = [5],
    beta = [10],
    z_dim = [9], 
    manual_seed = [1265],
    loss_func = [nn.MSELoss]
)


train_set = MNISTDataset(path='data/MNIST/processed', normalize=True)
val_set = MNISTDataset(path='data/MNIST/processed', train=False, normalize=True)

real_label = 0.9
fake_label = 0

criterion_adv = nn.BCELoss()
alpha = 0.0001
Tc = 0
Td = 0
T_train = 100

count=0
run_data = []

for run in RunBuilder.get_runs(params):
    device = torch.device(run.device)
#     aae = AAE(latent_dim=run.z_dim).to(run.device)
    try:
        aae = torch.load('trained_models/mnist_aae_jun7_z_9.model')
        print('loaded successfully')
    except:
        aae = AAE(latent_dim=run.z_dim).to(run.device)

    
    loader = torch.utils.data.DataLoader(train_set, 
                                         batch_size=run.batch_size, 
                                         shuffle=run.shuffle, 
                                         num_workers=run.num_workers
                                        )
    
    val_loader = torch.utils.data.DataLoader(val_set, 
                                         batch_size=run.batch_size, 
                                         shuffle=run.shuffle, 
                                         num_workers=run.num_workers
                                        )
    
    optimizerAE = torch.optim.Adam(aae.autoencoder.parameters(), lr=run.ae_lr)
    optimizerDiscriminator = torch.optim.Adam(aae.discriminator.parameters(), lr=run.d_lr)

    num_batches = len(train_set)/run.batch_size
    criterion_recons = run.loss_func()
    
    num_val_batches = len(val_set)/run.batch_size
    
    for epoch in range(27):

        total_adv_loss = 0
 
        real_count = 0
        fake_count = 0
        
        total_D_real = 0
        total_D_fake = 0
        
        total_D_loss = 0
        total_G_loss = 0
        
        total_recons_loss = 0
        
        num_batches = len(train_set)/run.batch_size
        i=0
        for batch in tqdm(loader):
            i+=1

            # Get data
            image = batch
            image = image.to(run.device)
            
            b_size = image.size(0)
            labels_real = torch.full((b_size,1), real_label, device=run.device, dtype=torch.float32)
            labels_fake = torch.full((b_size,1), fake_label, device=run.device, dtype=torch.float32)
            
            sample_real = torch.randn((b_size, run.z_dim), dtype=torch.float32)
            sample_real = sample_real.to(run.device)
            
            # update autoencoder
            optimizerAE.zero_grad()
            out = aae.autoencoder(image)
            loss = criterion_recons(out, image)
            loss.backward()
            optimizerAE.step()
            
            total_recons_loss += loss.item()
            
            
            # Update Generator
            optimizerAE.zero_grad()

            fake = aae.autoencoder.encoder(image)
            fake_pred = aae.discriminator(fake)
            gen_loss = criterion_adv(fake_pred, labels_real)
            gen_loss.backward()
            optimizerAE.step()
            
            # Update Discriminator
            optimizerDiscriminator.zero_grad()
            
            sample_fake = fake.detach().clone()
            
            real_loss = criterion_adv(aae.discriminator(sample_real), labels_real)
            fake_loss = criterion_adv(aae.discriminator(sample_fake), labels_fake)
            d_loss = 0.5*(real_loss + fake_loss)
            d_loss.backward()
            optimizerDiscriminator.step()
            
            total_D_real += real_loss.item()
            total_D_fake += fake_loss.item()
        
            total_D_loss += d_loss.item()
            total_G_loss += gen_loss.item()

            
        results = OrderedDict()
        results['recons_loss'] = total_recons_loss/num_batches
        results['G_loss'] = total_G_loss/num_batches
        results['D_loss'] = total_D_loss/num_batches
        results['d_lr'] = run.d_lr
        results['g_lr'] = run.ae_lr
        results['batch_size'] = [run.batch_size]
        run_data.append(results)
        df = pd.DataFrame.from_dict(run_data, orient='columns')
        clear_output(wait=True)
        display(df)
        
        torch.save(aae, 'trained_models/mnist_aae_jun7_z_9.model')


Unnamed: 0,recons_loss,G_loss,D_loss,d_lr,g_lr,batch_size
0,0.105319,0.78332,0.68699,0.001,0.001,[32]
1,0.105099,0.783045,0.687215,0.001,0.001,[32]
2,0.104804,0.782895,0.687283,0.001,0.001,[32]
3,0.104422,0.782868,0.687289,0.001,0.001,[32]
4,0.104031,0.782825,0.687157,0.001,0.001,[32]
5,0.103937,0.782636,0.687256,0.001,0.001,[32]
6,0.103852,0.782755,0.687221,0.001,0.001,[32]
7,0.103509,0.782722,0.687226,0.001,0.001,[32]
8,0.10339,0.783081,0.687037,0.001,0.001,[32]
9,0.103297,0.782766,0.687312,0.001,0.001,[32]


## Train semi-supervised AAE

In [3]:
#!/scratch/sagar/Projects/radio_map_deep_prior/venv/bin/python3.6

from collections import OrderedDict, namedtuple
from itertools import product
import os
from tqdm import tqdm, trange
from IPython.display import clear_output
import time
import torch.nn as nn
import pandas as pd
import torch
from models.aae import AAESemiSupervised
import matplotlib.pyplot as plt
from dataset import MNISTDataset, MNIST_mean, MNIST_std, MNISTSupervisedDataset, GMM, OneHot
from utils.plot_utils import show_latent
from utils.run_manager import RunBuilder


if torch.cuda.is_available():
    devices = ['cuda']
else:
    devices = ['cpu']
print('starting')

params = OrderedDict(
    d_lr = [0.001],
    ae_lr = [0.001],
    batch_size = [32],
    device = devices,
    shuffle = [True],
    num_workers = [5],
    beta = [10],
    z_dim = [2], 
    manual_seed = [1265],
    loss_func = [nn.MSELoss]
)


train_set = MNISTSupervisedDataset(path='data/MNIST/processed', normalize=True)
val_set = MNISTSupervisedDataset(path='data/MNIST/processed', train=False, normalize=True)

real_label = 0.9
fake_label = 0

criterion_adv = nn.BCELoss()
alpha = 0.0001
Tc = 0
Td = 0
T_train = 100

count=0
run_data = []


for run in RunBuilder.get_runs(params):
    device = torch.device(run.device)
    aae = AAESemiSupervised().to(run.device)

    
    loader = torch.utils.data.DataLoader(train_set, 
                                         batch_size=run.batch_size, 
                                         shuffle=run.shuffle, 
                                         num_workers=run.num_workers
                                        )
    
    val_loader = torch.utils.data.DataLoader(val_set, 
                                         batch_size=run.batch_size, 
                                         shuffle=run.shuffle, 
                                         num_workers=run.num_workers
                                        )
    
    gmm_means = GMM(radius=100, num_classes = 10)
    one_hot = OneHot(num_classes=10)
    
    optimizerAE = torch.optim.Adam(aae.autoencoder.parameters(), lr=run.ae_lr)
    optimizerDiscriminator = torch.optim.Adam(aae.discriminator.parameters(), lr=run.d_lr)

    num_batches = len(train_set)/run.batch_size
    criterion_recons = run.loss_func()
    
    num_val_batches = len(val_set)/run.batch_size
    
    for epoch in range(100):

        total_adv_loss = 0
 
        real_count = 0
        fake_count = 0
        
        total_D_real = 0
        total_D_fake = 0
        
        total_D_loss = 0
        total_G_loss = 0
        
        total_recons_loss = 0
        
        num_batches = len(train_set)/run.batch_size
        i=0
        for batch in tqdm(loader):
            i+=1

            # Get data
            image, labels = batch
            image = image.to(run.device)
            labels = labels.to(run.device)
            
            b_size = image.size(0)
            labels_real = torch.full((b_size,1), real_label, device=run.device, dtype=torch.float32)
            labels_fake = torch.full((b_size,1), fake_label, device=run.device, dtype=torch.float32)
            
            sample_real = gmm_means[labels] + torch.randn((b_size, run.z_dim), dtype=torch.float32)
            sample_real = torch.cat((sample_real, one_hot[labels]), dim=1)
            sample_real = sample_real.to(run.device)
            
            # update autoencoder
            optimizerAE.zero_grad()
            out = aae.autoencoder(image)
            loss = criterion_recons(out, image)
            loss.backward()
            optimizerAE.step()
            
            total_recons_loss += loss.item()
            
            
            # Update Generator
            optimizerAE.zero_grad()

            fake = aae.autoencoder.encoder(image)
            fake = torch.cat((fake, one_hot[labels].to(device)), dim=1)
            fake_pred = aae.discriminator(fake)
            gen_loss = criterion_adv(fake_pred, labels_real)
            gen_loss.backward()
            optimizerAE.step()
            
            # Update Discriminator
            optimizerDiscriminator.zero_grad()
            
            sample_fake = fake.detach().clone()
            
            real_loss = criterion_adv(aae.discriminator(sample_real), labels_real)
            fake_loss = criterion_adv(aae.discriminator(sample_fake), labels_fake)
            d_loss = 0.5*(real_loss + fake_loss)
            d_loss.backward()
            optimizerDiscriminator.step()
            
            total_D_real += real_loss.item()
            total_D_fake += fake_loss.item()
        
            total_D_loss += d_loss.item()
            total_G_loss += gen_loss.item()

            
        results = OrderedDict()
        results['recons_loss'] = total_recons_loss/num_batches
        results['G_loss'] = total_G_loss/num_batches
        results['D_loss'] = total_D_loss/num_batches
        results['d_lr'] = run.d_lr
        results['g_lr'] = run.ae_lr
        results['batch_size'] = [run.batch_size]
        run_data.append(results)
        df = pd.DataFrame.from_dict(run_data, orient='columns')
        clear_output(wait=True)
        display(df)
        
        torch.save(aae, 'trained_models/mnist_aae_semisupervised_r_100_2.model')


Unnamed: 0,recons_loss,G_loss,D_loss,d_lr,g_lr,batch_size
0,0.609448,1.419309,0.449500,0.001,0.001,[32]
1,0.573456,1.443974,0.476211,0.001,0.001,[32]
2,0.566959,1.498440,0.479951,0.001,0.001,[32]
3,0.564397,1.591559,0.467816,0.001,0.001,[32]
4,0.562942,1.685770,0.463018,0.001,0.001,[32]
...,...,...,...,...,...,...
95,0.544475,1.136931,0.645106,0.001,0.001,[32]
96,0.544379,1.159762,0.640798,0.001,0.001,[32]
97,0.544804,1.148005,0.648952,0.001,0.001,[32]
98,0.544339,1.182510,0.637769,0.001,0.001,[32]


In [91]:
gmm_means = GMM(radius=7, num_classes = 10)
one_hot = OneHot(num_classes=10)

labels = torch.tensor([0,1,2,3, 4, 5, 6, 7, 8, 9])
sample_real = gmm_means[labels] + torch.randn((10, 2), dtype=torch.float32)
sample_real = torch.cat((sample_real, one_hot[labels]), dim=1)

In [92]:
sample_real

tensor([[ 7.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 5.6631e+00,  4.1145e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 2.1631e+00,  6.6574e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-2.1631e+00,  6.6574e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-5.6631e+00,  4.1145e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [-7.0000e+00,  8.5725e-16,  0.0000e+00,  0.0000e+00,  0.0000e+00,
      

In [1]:
from dataset import MNISTSupervisedDataset



In [76]:
class GMM:
    def __init__(self, radius=7, num_classes=10):
        self.means = []
        angles = [2*np.pi/10*i for i in range(num_classes)]
        for theta in angles:
            self.means.append([radius*np.cos(theta), radius*np.sin(theta)])
        
    def __getitem__(self, indices: torch.Tensor):
        mapping = map(self.means.__getitem__, indices)
        accessed_list = torch.tensor(list(mapping), dtype=torch.float32)
        return accessed_list
    
class OneHot:
    def __init__(self, num_classes=10):
        self.encoding_vecs = np.eye(num_classes)
        
    def __getitem__(self, indices: torch.Tensor):
        mapping = map(self.encoding_vecs.__getitem__, indices)
        accessed_list = torch.tensor(list(mapping), dtype=torch.float32)
        return accessed_list
    

In [77]:
one_hot = OneHot()

In [79]:
one_hot[torch.tensor([1,3])]

tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])

In [75]:
list(np.eye(10))

[array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])]

In [68]:
a = torch.eye(10)

In [69]:
a

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

In [73]:
b = torch.tensor([4,4])
torch.cat((b,a[3]))

tensor([4., 4., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])

In [80]:
a = GMM()

In [86]:
a[torch.tensor([0,1, 2,3, 4, 5, 6, 7, 8, 9])]

tensor([[ 7.0000e+00,  0.0000e+00],
        [ 5.6631e+00,  4.1145e+00],
        [ 2.1631e+00,  6.6574e+00],
        [-2.1631e+00,  6.6574e+00],
        [-5.6631e+00,  4.1145e+00],
        [-7.0000e+00,  8.5725e-16],
        [-5.6631e+00, -4.1145e+00],
        [-2.1631e+00, -6.6574e+00],
        [ 2.1631e+00, -6.6574e+00],
        [ 5.6631e+00, -4.1145e+00]])

In [None]:
list()

In [60]:
a =torch.tensor([1,1])

In [45]:
list(a)

[tensor(5), tensor(5)]

In [31]:
angles = [2*np.pi/10*i for i in range(10)]

In [62]:
a_list = [[1,2], 2, 3]

indices_to_access = [0, 2, 0, 1, 1]


accessed_mapping = map(a_list.__getitem__, indices_to_access)

accessed_list = list(accessed_mapping)


print(accessed_list)

[[1, 2], 3, [1, 2], 2, 2]


In [39]:
torch.tensor(accessed_mapping, dtype=torch.float32)

TypeError: must be real number, not map

In [30]:
2*np.pi

6.283185307179586

In [27]:
import numpy as np

In [28]:
np.pi

3.141592653589793

In [23]:
torch.arccos

<function _VariableFunctionsClass.arccos>

In [2]:
train_set = MNISTSupervisedDataset(path='data/MNIST/processed', normalize=True)

  self.dataset = torch.tensor(self.dataset, dtype=torch.float32)


In [3]:
a,b = train_set[0]

In [5]:
a.shape

torch.Size([1, 28, 28])

In [6]:
b.shape

torch.Size([])

In [7]:
b

tensor(5)

In [9]:
import torch

In [10]:
loader = torch.utils.data.DataLoader(train_set, 
                                         batch_size=15)

In [11]:
a = enumerate(loader)

In [13]:
c,b = next(a)

In [19]:
r,t = b

In [67]:
t

tensor([7, 2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7])