In [10]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import h5py


In [11]:
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

from torch.utils.data import TensorDataset, DataLoader

In [12]:
import sys
# Import FBP
sys.path.append('../FBPConvNet/')
from FBPConvNet import FBPConvNet, Discriminator

sys.path.append('../')
from net_utils import get_datetime

# Generate Example Data

In [15]:
def preprocess(data):
    return torch.Tensor(data).unsqueeze(1)
def target_ones(N,gpu=False):
    if GPU:
        return torch.ones(N,1).cuda()
    else:
        return torch.ones(N,1)
def target_zeros(N,gpu=False):
    if GPU:
        return torch.zeros(N,1).cuda()
    else:
        return torch.zeros(N,1)   

In [13]:
# set path to data
pathtodata = '../EllipseGeneration/RandomLineEllipses15.hdf5'
dataset_size = 200
batch_size = 1

In [16]:
f = h5py.File(pathtodata,'r')
print(f['ellip/training_data'].shape)
fakeinput = preprocess(f['ellip/training_data'][0:dataset_size])
fakelabels = preprocess(f['ellip/training_labels'][0:dataset_size])
reallabels = preprocess(f['ellip/training_labels'][dataset_size:2*dataset_size])
#f.close()

(400, 256, 256)


In [17]:
faketrainset = TensorDataset(fakeinput,fakelabels)
#realset = TensorDataset(reallabels)

faketrainloader = DataLoader(faketrainset,batch_size=batch_size,shuffle=True)
realtrainloader = DataLoader(reallabels, batch_size=batch_size,shuffle=True)

In [None]:
len(realtrainloader)
realiter = iter(realtrainloader)
for i in range(len(realtrainloader)):
    data = realiter.next()
    print(data.size())

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Input 1 x 256 x 256 -> 64 x 256 x 256
        self.conv1 = nn.Conv2d(1,64,3,padding=1,stride=1)    
        self.batch1 = nn.BatchNorm2d(64)
        
        # -> 128 x 128 x 128
        self.conv2 = nn.Conv2d(64,128,7,padding=3,stride=2)
        self.batch2 = nn.BatchNorm2d(128)

        # -> 256 x 32 x 32 -> 256 x 22 x 22
        self.conv3 = nn.Conv2d(128,256,5,padding=2,stride=2)
        self.batch3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256,256,7,padding=3,stride=3)
        self.batch4 = nn.BatchNorm2d(256)
        
        # -> 512 x 8 x 8
        self.conv5 = nn.Conv2d(256,512,5,padding=2,stride=3)
        self.batch5 = nn.BatchNorm2d(512)
        
        # -> 1024 x 1 x 1
        self.conv6 = nn.Conv2d(512,1024,5,padding=2,stride=3)
        self.batch6 = nn.BatchNorm2d(1024)
        self.conv7 = nn.Conv2d(1024,1024,3,padding=1,stride=3)
        #self.batch7 = nn.BatchNorm2d(1024)
        
        # Decision layers
        self.conv8 = nn.Conv2d(1024,1024,1)
        self.conv9 = nn.Conv2d(1024,1024,1)
        self.conv10 = nn.Conv2d(1024,1,1)
        
        # Non-Linear Activations
        self.leaky = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()
        self.conv2_drop = nn.Dropout2d(p=.2)
        
        
        
    def forward(self,x):
        # Generate Features
        x = self.leaky(self.batch1(self.conv1(x)))
        x = self.leaky(self.batch2(self.conv2(x)))
        x = self.leaky(self.batch3(self.conv3(x)))
        x = self.leaky(self.batch4(self.conv4(x)))
        x = self.leaky(self.batch5(self.conv5(x)))
        x = self.leaky(self.batch6(self.conv6(x)))
        x = self.leaky(self.conv7(x))
        
        # Decision Layers
        x = self.leaky(self.conv2_drop(self.conv8(x)))
        x = self.leaky(self.conv2_drop(self.conv9(x)))
        x = self.sigmoid(self.conv10(x))
        return x[:,:,0,0]
#         x1_1 = self.batch1(self.elu(self.conv1_1(x)))
#         x1_2 = self.batch1(self.elu(self.conv1_2(x1_1)))
#         x1_3 = self.batch1(self.elu(self.conv1_2(x1_2)))
#         x1 = self.maxpool(x1_3)
        
#         x2_1 = self.batch2(self.elu(self.conv2_1(x1)))
#         x2_2 = self.batch2(self.elu(self.conv2_2(x2_1)))
#         x2 = self.maxpool(x2_2)
        
#         x3_1 = self.batch3(self.elu(self.conv3_1(x2)))
#         x3_2 = self.batch3(self.elu(self.conv3_2(x3_1)))
#         x3 = self.maxpool(x3_2)
        
#         x4_1 = self.batch4(self.elu(self.conv4_1(x3)))
#         x4_2 = self.batch4(self.elu(self.conv4_2(x4_1)))

#         x5_1 = self.deconv5(x4_2)
#         x5_2 = torch.cat((x3_2,x5_1),1)
#         x5_3 = self.batch3(self.elu(self.conv5_1(x5_2)))
#         x5 = self.batch3(self.elu(self.conv5_2(x5_3)))
        
#         x6_1 = self.deconv6(x5)
#         x6_2 = torch.cat((x2_2,x6_1),1)
#         x6_3 = self.batch2(self.elu(self.conv6_1(x6_2)))
#         x6 = self.batch2(self.elu(self.conv6_2(x6_3)))
        
#         x7_1 = self.deconv7(x6)
#         x7_2 = torch.cat((x1_3,x7_1),1)
#         x7_3 = self.batch1(self.elu(self.conv7_1(x7_2)))
#         x7 = self.batch1(self.elu(self.conv7_2(x7_3)))
        
#         x8 = self.conv8(x7)
#         y = x8 + x
        
#         return y

In [None]:
net = Discriminator().cuda()
x = Variable(fakeinput)
print(x.size())
out = net.forward(x)
print(out.size())
print(out)

In [None]:
m = nn.Softplus()
input_ = Variable(torch.randn(2))
print(input_)
print(m(input_))

# Training Pseudocode

In [6]:
import time
import os

In [18]:
def train_GANs(G, D, faketrainloader, realtrainloader, num_epochs=500, GPU=False,
              weightpath='./weights/',save_epoch=50,saveweights=True):
    # Create output directory
    weightpath = os.path.join(weightpath,get_datetime())
    os.makedirs(weightpath)
    logpath = os.path.join(weightpath,'log.txt')
    
    with open(logpath, "wt") as text_file:
        print('Epoch\tD Loss\tG Loss\tEpoch Time\tTotal Time',file=text_file)

    num_data = len(realtrainloader)*realtrainloader.batch_size 
    d_losses = np.zeros(num_epochs)
    g_losses = np.zeros(num_epochs)

    # Accumulate log text
    logtxt = ''
    
    # Determine minibatch size
    minibatch = max(1,int(len(realtrainloader))/10)
    
    # Define Loss Function/Optimizer
    bceloss = nn.BCELoss()
    mseloss = nn.MSELoss()

    d_optimizer = optim.Adam(D.parameters(), lr=0.0002)
    g_optimizer = optim.Adam(G.parameters(), lr=0.0002)

    
    G.train()
    trainstart = time.time()
    for epoch in range(num_epochs):
        # Collect loss information
        d_epoch_loss = 0.0
        g_epoch_loss = 0.0
        d_running_loss = 0.0
        g_running_loss = 0.0
        
        epochstart = time.time()

        fakeiter = iter(faketrainloader)
        realiter = iter(realtrainloader)
        Giter = iter(faketrainloader)
        for batch_index in range(len(realtrainloader)):
            ## prepare data
            truelabels = realiter.next()
            fakeinput, _ = fakeiter.next()
            batch_size = truelabels.size(0)


            if GPU:
                truelabels = truelabels.cuda()
                fakeinput = fakeinput.cuda()
    #             fakelabel = fakelabel.cuda()
            d_real_data = Variable(truelabels)
            d_gen_input = Variable(fakeinput)
            d_fake_data = G(d_gen_input).detach() # detach to avoid training G on these labels
            
            ## Train D
            d_optimizer.zero_grad()

            # Train D on real
            d_real_decision = D(d_real_data)
            d_real_error = bceloss(d_real_decision, Variable(target_ones(batch_size,GPU)))
            d_real_error.backward()

            # Train D on fake

            d_fake_decision = D(d_fake_data)
            d_fake_error = bceloss(d_fake_decision, Variable(target_zeros(batch_size,GPU))) 
            d_fake_error.backward()
            d_optimizer.step()
            d_loss = d_real_error+d_fake_error
            
            d_running_loss += d_loss.data[0]
            d_losses[epoch] += d_loss.data[0]
            
        
            ## Train G
            g_fake_input, g_fake_label = Giter.next()
            batch_size = g_fake_input.size(0)

            if GPU:
                g_fake_input = g_fake_input.cuda()
                g_fake_label = g_fake_label.cuda()

            gen_input = Variable(g_fake_input)
            g_fake_data = G(gen_input)
  
            g_optimizer.zero_grad()

            dg_fake_decision = D(g_fake_data)
            g_loss = (10**-3)*bceloss(dg_fake_decision, Variable(target_ones(batch_size,GPU)))
            g_loss +=  mseloss(g_fake_data,Variable(g_fake_label))

            g_loss.backward()
            g_optimizer.step()
            
            g_running_loss += g_loss.data[0]
            g_losses[epoch] += g_loss.data[0]
            
            # print statistics
            if batch_index % minibatch == 0:
                print('\t[%d, %5d] D loss: %.3f, G loss: %.3f, %.3f seconds elapsed' %
                      (epoch + 1, batch_index + 1, d_running_loss / minibatch, 
                       g_running_loss/minibatch, time.time() - epochstart))
                d_running_loss = 0.0
                g_running_loss = 0.0
        # Record epoch statistics
        epochend = time.time()        
        print('Epoch %d Training Time: %.3f seconds\nTotal Elapsed Time: %.3f seconds' %
               (epoch+1, epochend-epochstart,epochend-trainstart))
        
        # log losses
        d_losses[epoch] /= num_data
        g_losses[epoch] /= num_data
        logtxt += '%i\t%f\%f\t%f\t%f\n' % (epoch+1,d_losses[epoch], g_losses[epoch],
                                           epochend-epochstart,epochend-trainstart)

        
        # Save weights
        if (epoch % save_epoch == 0 or epoch == num_epochs-1):
            if saveweights:
                d_outpath = os.path.join(weightpath,'D_epoch_'+str(epoch+1)+'.weights')
                g_outpath = os.path.join(weightpath,'G_epoch_'+str(epoch+1)+'.weights')
                D = D.cpu()
                G = G.cpu()
                torch.save(D.state_dict(),d_outpath)
                torch.save(G.state_dict(),g_outpath)

                if GPU:
                    D = D.cuda()
                    G = G.cuda()
            
            # write loss to logfile
            with open(logpath, "at") as text_file:
                print(logtxt[:-2],file=text_file)
                logtxt = ''

    print('Finished Training')
    return d_losses,g_losses


In [19]:
D = Discriminator()
G = FBPConvNet()

In [21]:
torch.cuda.is_available()

True

In [20]:
torch.cuda.empty_cache()
num_epochs = 10

GPU = True
if GPU:
    D = D.cuda()
    G = G.cuda()
d_losses, g_losses = train_GANs(G,D,faketrainloader,realtrainloader,num_epochs=num_epochs,GPU=GPU)

	[1,     1] D loss: 0.069, G loss: 0.010, 3.016 seconds elapsed
	[1,    21] D loss: 1.313, G loss: 0.020, 21.829 seconds elapsed
	[1,    41] D loss: 0.700, G loss: 0.010, 40.758 seconds elapsed
	[1,    61] D loss: 0.624, G loss: 0.008, 59.681 seconds elapsed
	[1,    81] D loss: 0.403, G loss: 0.007, 78.721 seconds elapsed
	[1,   101] D loss: 0.954, G loss: 0.011, 97.797 seconds elapsed
	[1,   121] D loss: 0.210, G loss: 0.005, 116.852 seconds elapsed
	[1,   141] D loss: 0.906, G loss: 0.008, 135.871 seconds elapsed
	[1,   161] D loss: 0.608, G loss: 0.013, 154.950 seconds elapsed
	[1,   181] D loss: 0.081, G loss: 0.011, 174.036 seconds elapsed
Epoch 1 Training Time: 192.121 seconds
Total Elapsed Time: 192.121 seconds
	[2,     1] D loss: 0.000, G loss: 0.001, 0.967 seconds elapsed
	[2,    21] D loss: 1.642, G loss: 0.006, 19.986 seconds elapsed
	[2,    41] D loss: 0.514, G loss: 0.006, 39.010 seconds elapsed
	[2,    61] D loss: 0.012, G loss: 0.014, 58.037 seconds elapsed
	[2,    81] D

KeyboardInterrupt: 

In [None]:
plt.figure()
plt.subplot(121)
plt.plot(d_losses)
plt.title('Discriminator Losses')
plt.subplot(122)
plt.plot(g_losses)
plt.title('Generator Losses')

In [None]:
fakeiter = iter(faketrainloader)

In [None]:
y,x = fakeiter.next()

GPU = True
if GPU:
    xhat = G(Variable(y.cuda()))
    xhat = xhat.cpu().data
else:
    xhat = G(Variable(y)).data
plt.figure(figsize=(12,4))
plt.subplot(131)
mse =torch.mean((y[0,0,:,:]-x[0,0,:,:])**2)
plt.imshow(y[0,0,...].numpy())
plt.title('Input, mse=%.3f'%(mse))
plt.axis('off')

mse =torch.mean((xhat[0,0,:,:]-x[0,0,:,:])**2)
plt.subplot(132)
plt.imshow(xhat[0,0,...].numpy())
plt.title('Output, mse=%.3f'%(mse))
plt.axis('off')

plt.subplot(133)
plt.imshow(x[0,0,...].numpy())
plt.title('GT')
plt.axis('off')

In [None]:
print(np.min(y.numpy()),np.max(y.numpy()))

In [None]:
G = G.cpu()
torch.save(G.state_dict(),'Gmse_100.weights')
D = D.cpu()
torch.save(D.state_dict(),'Dmse_100.weights')

if GPU:
    G = G.cuda()
    D = D.cuda()
