In [1]:
#imports
import argparse
import torch
import torch.utils.data
import pickle
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import tables

In [2]:
# Training settings
cuda = torch.cuda.is_available()

seed = 10


kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
imagesize ={'x':100,'y':100,'c':3}
n_classes = 10
X_dim = imagesize['x']*imagesize['y']*imagesize['c']
z_dim = 1000
y_dim = 10
train_batch_size = 20
valid_batch_size = 20
N = 5000
epochs = 500 #500

In [3]:
##################################
# Load data and create Data loaders
##################################
def load_data(data_path='../data/'):
    print('loading data!')
    filepath = "data/train_data"
    hdf5_file = tables.open_file(filepath, "r")
    # To access images array:
    trainset_unlabeled = hdf5_file.root.images
    #hdf5_file = tables.open_file(filepath, "r")
    #for i in range(len(trainset_unlabeled)):
    #    plt.imshow(trainset_unlabeled[i])
    #    plt.show()
    #print(len(trainset_unlabeled))
    #print(np.shape(trainset_unlabeled[0]))
    #trainset_unlabeled = pickle.load(open(data_path + "train_unlabeled.p", "rb"))
    # Set -1 as labels for unlabeled data
    trainset_unlabeled.train_labels = hdf5_file.root.labels
    filepath = "data/test_data"
    hdf5_file = tables.open_file(filepath, "r")
    # To access images array:
    validset_unlabeled = hdf5_file.root.images
    trainset_unlabeled.labels = hdf5_file.root.labels
    
    train_unlabeled_loader = torch.utils.data.DataLoader(trainset_unlabeled,
                                                         batch_size=train_batch_size,
                                                         shuffle=True, **kwargs)

    return train_unlabeled_loader,trainset_unlabeled

In [4]:

##################################
# Define Networks
##################################
# Encoder
class Q_net(nn.Module):
    def __init__(self):
        super(Q_net, self).__init__()
        self.lin1 = nn.Linear(X_dim, N)
        self.lin2 = nn.Linear(N, N)
        # Gaussian code (z)
        self.lin3gauss = nn.Linear(N, z_dim)

    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.2, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.2, training=self.training)
        x = F.relu(x)
        xgauss = self.lin3gauss(x)

        return xgauss


In [5]:

# Decoder
class P_net(nn.Module):
    def __init__(self):
        super(P_net, self).__init__()
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, X_dim)

    def forward(self, x):
        x = self.lin1(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.lin3(x)
        return F.sigmoid(x)

In [6]:

class D_net_gauss(nn.Module):
    def __init__(self):
        super(D_net_gauss, self).__init__()
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, 1)

    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.2, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.2, training=self.training)
        x = F.relu(x)

        return F.sigmoid(self.lin3(x))

In [7]:

####################
# Utility functions
####################
def save_model(model, filename):
    print('Best model so far, saving it...')
    torch.save(model.state_dict(), filename)


def report_loss(epoch, D_loss_gauss, G_loss, recon_loss,valid_recon_loss):
    '''
    Print loss
    '''
    print('Epoch-{}; D_loss_gauss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}; valid_recon_loss: {:.4}'.format(epoch,
                                                                                   D_loss_gauss.data[0],
                                                                                   G_loss.data[0],
                                                                                   recon_loss.data[0],valid_recon_loss.data[0]))


def create_latent(Q, loader):
    '''
    Creates the latent representation for the samples in loader
    return:
        z_values: numpy array with the latent representations
        labels: the labels corresponding to the latent representations
    '''
    Q.eval()
    labels = []

    for batch_idx, (X, target) in enumerate(loader):

        X = X * 0.3081 + 0.1307
        # X.resize_(loader.batch_size, X_dim)
        X, target = Variable(X), Variable(target)
        labels.extend(target.data.tolist())
        if cuda:
            X, target = X.cuda(), target.cuda()
        # Reconstruction phase
        z_sample = Q(X)
        if batch_idx > 0:
            z_values = np.concatenate((z_values, np.array(z_sample.data.tolist())))
        else:
            z_values = np.array(z_sample.data.tolist())
    labels = np.array(labels)

    return z_values, labels


In [8]:

       
####################
# Train procedure
####################
def train(P, Q, D_gauss, P_decoder, Q_encoder, Q_generator, D_gauss_solver, data_loader,valid_data,update_disc = True):
    '''
    Train procedure for one epoch.
    '''
    TINY = 1e-15
    # Set the networks in train mode (apply dropout when needed)
    Q.train()
    P.train()
    D_gauss.train()

    # Loop through the labeled and unlabeled dataset getting one batch of samples from each
    # The batch size has to be a divisor of the size of the dataset or it will return
    # invalid samples
    last_g_loss = 100
    last_g_grad = 1
    v_size = len(valid_data)
    v_data = np.zeros((v_size,imagesize['x'],imagesize['y'],imagesize['c']))
    for i in range(v_size):
        v_data[i] = np.array(valid_data[i])
    v_data = 1 - v_data/255
    V = Variable(torch.FloatTensor(v_data).resize_(valid_batch_size, X_dim))
    if cuda:
            V= V.cuda()
    for X in data_loader:
        # Load batch and normalize samples to be between 0 and 1
        X = 1-X/255
        X.resize_(train_batch_size, X_dim)
        X= Variable(X)
        X = X.type(torch.FloatTensor)
        if cuda:
            X= X.cuda()

        # Init gradients
        P.zero_grad()
        Q.zero_grad()
        D_gauss.zero_grad()

        #######################
        # Reconstruction phase
        #######################
        z_sample = Q(X)
        X_sample = P(z_sample)
        recon_loss = F.binary_cross_entropy(X_sample + TINY, X.resize(train_batch_size, X_dim) + TINY)
        recon_loss.backward()
        P_decoder.step()
        Q_encoder.step()

        P.zero_grad()
        Q.zero_grad()
        
        # recon loss for valid set
        z_sample = Q(V)
        X_sample = P(z_sample)
        recon_loss_valid = F.binary_cross_entropy(X_sample + TINY, X.resize(train_batch_size, X_dim) + TINY)
        P.zero_grad()
        Q.zero_grad()
        
        D_gauss.zero_grad()

        #######################
        # Regularization phase
        #######################
        # Discriminator
        Q.eval()
        z_real_gauss = Variable(torch.randn(train_batch_size, z_dim) * 5.)
        if cuda:
            z_real_gauss = z_real_gauss.cuda()

        z_fake_gauss = Q(X)

        D_real_gauss = D_gauss(z_real_gauss)
        D_fake_gauss = D_gauss(z_fake_gauss)
        D_loss = -torch.mean(torch.log(D_real_gauss + TINY) + torch.log(1 - D_fake_gauss + TINY))
        if last_g_loss < D_loss.data[0]: #update_disc:
            #print(last_g_grad,D_loss.data[0])
            #print('updating descriminator')
            D_loss.backward()
            D_gauss_solver.step()

        P.zero_grad()
        Q.zero_grad()
        D_gauss.zero_grad()

        # Generator
        Q.train()
        z_fake_gauss = Q(X)

        D_fake_gauss = D_gauss(z_fake_gauss)
        G_loss = -torch.mean(torch.log(D_fake_gauss + TINY))
        last_g_grad = last_g_loss - G_loss.data[0]
        last_g_loss = G_loss.data[0]
        G_loss.backward()
        Q_generator.step()

        P.zero_grad()
        Q.zero_grad()
        D_gauss.zero_grad()
    return D_loss, G_loss, recon_loss, recon_loss_valid



In [9]:
def generate_model( train_unlabeled_loader,valid_data):
    torch.manual_seed(10)

    if cuda:
        Q = Q_net().cuda()
        P = P_net().cuda()
        D_gauss = D_net_gauss().cuda()
    else:
        Q = Q_net()
        P = P_net()
        D_gauss = D_net_gauss()

    # Set learning rates
    gen_lr = 0.001
    reg_lr = 0.00005

    # Set optimizators
    P_decoder = optim.Adam(P.parameters(), lr=gen_lr)
    Q_encoder = optim.Adam(Q.parameters(), lr=gen_lr)

    Q_generator = optim.Adam(Q.parameters(), lr=reg_lr)
    D_gauss_solver = optim.Adam(D_gauss.parameters(), lr=reg_lr)
    losses = []
    update_disc = True
    for epoch in range(epochs):
        D_loss_gauss, G_loss, recon_loss,recon_loss_valid = train(P, Q, D_gauss, P_decoder, Q_encoder,
                                                 Q_generator,
                                                 D_gauss_solver,
                                                 train_unlabeled_loader,valid_data,update_disc)
        if epoch % 1 == 0:#chaged here!
            report_loss(epoch, D_loss_gauss, G_loss, recon_loss,recon_loss_valid)
            losses.append((epoch, D_loss_gauss.data[0], G_loss.data[0], recon_loss.data[0],recon_loss_valid.data[0]))
            #update_disc = D_loss_gauss.data[0] > G_loss.data[0]
    return Q, P

In [10]:
if __name__ == '__main__':
    train_unlabeled_loader,valid_data = load_data()
    Q, P = generate_model( train_unlabeled_loader,valid_data)

loading data!
Epoch-0; D_loss_gauss: 1.002; G_loss: 1.312; recon_loss: 0.6932; valid_recon_loss: 0.7187
Epoch-1; D_loss_gauss: 1.415; G_loss: 0.6324; recon_loss: 1.118; valid_recon_loss: 0.5964
Epoch-2; D_loss_gauss: 1.889; G_loss: 0.294; recon_loss: 0.5297; valid_recon_loss: 1.856
Epoch-3; D_loss_gauss: 1.511; G_loss: 0.5904; recon_loss: 4.769; valid_recon_loss: 0.6804
Epoch-4; D_loss_gauss: 1.283; G_loss: 0.6616; recon_loss: 1.245; valid_recon_loss: 0.6114
Epoch-5; D_loss_gauss: 1.352; G_loss: 0.6796; recon_loss: 0.5293; valid_recon_loss: 0.5834
Epoch-6; D_loss_gauss: 1.402; G_loss: 0.6654; recon_loss: 0.6137; valid_recon_loss: 0.5796
Epoch-7; D_loss_gauss: 1.338; G_loss: 0.6643; recon_loss: 0.8726; valid_recon_loss: 0.5677
Epoch-8; D_loss_gauss: 1.325; G_loss: 0.6671; recon_loss: 0.4802; valid_recon_loss: 0.6329
Epoch-9; D_loss_gauss: 1.357; G_loss: 0.651; recon_loss: 0.5638; valid_recon_loss: 0.5853
Epoch-10; D_loss_gauss: 1.373; G_loss: 0.6236; recon_loss: 0.4868; valid_recon_loss

Epoch-90; D_loss_gauss: 2.19; G_loss: 0.2504; recon_loss: 0.3497; valid_recon_loss: 0.4527
Epoch-91; D_loss_gauss: 2.124; G_loss: 0.2748; recon_loss: 0.3557; valid_recon_loss: 0.5062
Epoch-92; D_loss_gauss: 2.073; G_loss: 0.2743; recon_loss: 0.3365; valid_recon_loss: 0.5173
Epoch-93; D_loss_gauss: 2.036; G_loss: 0.2541; recon_loss: 0.3425; valid_recon_loss: 0.5182
Epoch-94; D_loss_gauss: 2.015; G_loss: 0.2899; recon_loss: 0.3548; valid_recon_loss: 0.5307
Epoch-95; D_loss_gauss: 2.058; G_loss: 0.2758; recon_loss: 0.342; valid_recon_loss: 0.5375
Epoch-96; D_loss_gauss: 2.025; G_loss: 0.2972; recon_loss: 0.3517; valid_recon_loss: 0.5393
Epoch-97; D_loss_gauss: 2.113; G_loss: 0.2679; recon_loss: 0.3474; valid_recon_loss: 0.4911
Epoch-98; D_loss_gauss: 2.137; G_loss: 0.2354; recon_loss: 0.3388; valid_recon_loss: 0.4704
Epoch-99; D_loss_gauss: 2.134; G_loss: 0.256; recon_loss: 0.3368; valid_recon_loss: 0.4875
Epoch-100; D_loss_gauss: 2.262; G_loss: 0.2613; recon_loss: 0.3367; valid_recon_los

Epoch-179; D_loss_gauss: 2.971; G_loss: 0.1274; recon_loss: 0.2184; valid_recon_loss: 0.595
Epoch-180; D_loss_gauss: 3.008; G_loss: 0.132; recon_loss: 0.2307; valid_recon_loss: 0.6289
Epoch-181; D_loss_gauss: 3.076; G_loss: 0.1068; recon_loss: 0.2467; valid_recon_loss: 0.685
Epoch-182; D_loss_gauss: 3.059; G_loss: 0.1221; recon_loss: 0.2117; valid_recon_loss: 0.519
Epoch-183; D_loss_gauss: 3.108; G_loss: 0.1304; recon_loss: 0.1918; valid_recon_loss: 0.5835
Epoch-184; D_loss_gauss: 3.203; G_loss: 0.1212; recon_loss: 0.2046; valid_recon_loss: 0.5703
Epoch-185; D_loss_gauss: 3.3; G_loss: 0.09942; recon_loss: 0.2158; valid_recon_loss: 0.5062
Epoch-186; D_loss_gauss: 3.451; G_loss: 0.08439; recon_loss: 0.2248; valid_recon_loss: 0.6567
Epoch-187; D_loss_gauss: 3.46; G_loss: 0.1012; recon_loss: 0.226; valid_recon_loss: 0.5605
Epoch-188; D_loss_gauss: 3.534; G_loss: 0.09066; recon_loss: 0.2315; valid_recon_loss: 0.6574
Epoch-189; D_loss_gauss: 3.607; G_loss: 0.08645; recon_loss: 0.1984; valid_

Epoch-268; D_loss_gauss: 3.643; G_loss: 0.04931; recon_loss: 0.1265; valid_recon_loss: 1.437
Epoch-269; D_loss_gauss: 3.366; G_loss: 0.07518; recon_loss: 0.1388; valid_recon_loss: 1.284
Epoch-270; D_loss_gauss: 3.451; G_loss: 0.0668; recon_loss: 0.1194; valid_recon_loss: 0.9755
Epoch-271; D_loss_gauss: 3.482; G_loss: 0.07292; recon_loss: 0.1018; valid_recon_loss: 1.357
Epoch-272; D_loss_gauss: 3.583; G_loss: 0.07386; recon_loss: 0.1068; valid_recon_loss: 1.363
Epoch-273; D_loss_gauss: 3.655; G_loss: 0.07246; recon_loss: 0.1531; valid_recon_loss: 1.162
Epoch-274; D_loss_gauss: 3.658; G_loss: 0.06908; recon_loss: 0.1152; valid_recon_loss: 1.349
Epoch-275; D_loss_gauss: 3.816; G_loss: 0.08167; recon_loss: 0.1149; valid_recon_loss: 1.681
Epoch-276; D_loss_gauss: 3.887; G_loss: 0.05692; recon_loss: 0.1323; valid_recon_loss: 1.271
Epoch-277; D_loss_gauss: 3.987; G_loss: 0.0386; recon_loss: 0.1094; valid_recon_loss: 1.47
Epoch-278; D_loss_gauss: 4.128; G_loss: 0.05339; recon_loss: 0.1032; val

Epoch-357; D_loss_gauss: 4.642; G_loss: 0.04046; recon_loss: 0.1652; valid_recon_loss: 0.8989
Epoch-358; D_loss_gauss: 4.629; G_loss: 0.04212; recon_loss: 0.1312; valid_recon_loss: 1.104
Epoch-359; D_loss_gauss: 4.309; G_loss: 0.048; recon_loss: 0.2164; valid_recon_loss: 1.384
Epoch-360; D_loss_gauss: 4.613; G_loss: 0.06327; recon_loss: 0.1225; valid_recon_loss: 0.8679
Epoch-361; D_loss_gauss: 4.273; G_loss: 0.05935; recon_loss: 0.1235; valid_recon_loss: 1.093
Epoch-362; D_loss_gauss: 4.207; G_loss: 0.04918; recon_loss: 0.0986; valid_recon_loss: 1.198
Epoch-363; D_loss_gauss: 4.388; G_loss: 0.08835; recon_loss: 0.09888; valid_recon_loss: 0.8014
Epoch-364; D_loss_gauss: 4.159; G_loss: 0.0537; recon_loss: 0.1419; valid_recon_loss: 1.261
Epoch-365; D_loss_gauss: 4.087; G_loss: 0.06078; recon_loss: 0.1305; valid_recon_loss: 1.375
Epoch-366; D_loss_gauss: 4.226; G_loss: 0.0768; recon_loss: 0.1116; valid_recon_loss: 1.073
Epoch-367; D_loss_gauss: 4.006; G_loss: 0.04373; recon_loss: 0.09046; 

Epoch-446; D_loss_gauss: 5.171; G_loss: 0.01465; recon_loss: 0.2036; valid_recon_loss: 1.833
Epoch-447; D_loss_gauss: 4.935; G_loss: 0.02995; recon_loss: 0.3403; valid_recon_loss: 2.694
Epoch-448; D_loss_gauss: 5.268; G_loss: 0.03812; recon_loss: 0.1595; valid_recon_loss: 1.735
Epoch-449; D_loss_gauss: 5.762; G_loss: 0.02189; recon_loss: 0.09141; valid_recon_loss: 1.475
Epoch-450; D_loss_gauss: 6.034; G_loss: 0.01792; recon_loss: 0.0682; valid_recon_loss: 2.013
Epoch-451; D_loss_gauss: 6.184; G_loss: 0.006182; recon_loss: 0.1641; valid_recon_loss: 1.865
Epoch-452; D_loss_gauss: 6.253; G_loss: 0.00939; recon_loss: 0.1119; valid_recon_loss: 1.829
Epoch-453; D_loss_gauss: 5.882; G_loss: 0.01619; recon_loss: 0.198; valid_recon_loss: 1.704
Epoch-454; D_loss_gauss: 6.735; G_loss: 0.02453; recon_loss: 0.4546; valid_recon_loss: 2.173
Epoch-455; D_loss_gauss: 6.917; G_loss: 0.005331; recon_loss: 0.1435; valid_recon_loss: 1.278
Epoch-456; D_loss_gauss: 7.199; G_loss: 0.007543; recon_loss: 0.7552

In [None]:
import PIL.Image
from io import BytesIO
import IPython.display
import numpy as np
import math

def hardcodedTransform(array):
    return (1-array/255)

def transformForNetwork(array):
    return Variable(torch.FloatTensor( hardcodedTransform(array.reshape(imagesize['x']*imagesize['y']*imagesize['c'])))).cuda()

def showarray(a,do = False, fmt='png'):
    #a = np.uint8(hardcodedTransform(a)*256)
    f = BytesIO()
    if imagesize['c'] == 1:
        PIL.Image.fromarray(a.reshape(imagesize['x'],imagesize['y'])).save(f, fmt)
        IPython.display.display(IPython.display.Image(data=f.getvalue()))
    if imagesize['c'] == 3:
        plt.imshow(a)
        plt.show()

In [None]:
from PIL import Image
from IPython.display import Image, display

TINY = 1e-15
train_batch_size = 1

my_loader = load_data()

i = 0
max_images = 20
out = []
imgList = np.zeros((max_images,imagesize['x'],imagesize['y'],imagesize['c']))
reconimgList = np.zeros((max_images,imagesize['x'],imagesize['y'],imagesize['c']))
imgZValue = np.zeros((max_images,z_dim))
filepath = "data/train_data"
hdf5_file = tables.open_file(filepath, "r")
# To access images array:
trainset_unlabeled = hdf5_file.root.images
batches = np.split(trainset_unlabeled, len(trainset_unlabeled)/train_batch_size)
for  X in batches:
    if i>=max_images:
        break
    out = X[0]
    imgList[i] = out
    t = Q(transformForNetwork(out))
    imgZValue[i] = t.cpu().data.numpy()
    reconimg =  P(t)
    reconimgList[i] = reconimg.cpu().data.numpy().reshape(imagesize['x'],imagesize['y'],imagesize['c'])
    showarray(out)
    showarray(1-(reconimgList[i]))
    i+=1

In [None]:
showarray(imgList[0])

In [None]:
diff_matrix = np.zeros((max_images,max_images))
for i in range(0,max_images):
    for j in range(0,max_images):
        #print(diff_matrix[i,j])
        diff_matrix[i,j]=math.sqrt(sum(np.power(imgZValue[i] - imgZValue[j],2)))
        #print(imgZValue[i] - imgZValue[j])
#print(diff_matrix)
plt.scatter(imgZValue[:,0],imgZValue[:,1])
plt.show()