In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torch.nn.functional as F
import torch.optim as optim
from torchvision.utils import save_image

import argparse
import os

import numpy as np
import matplotlib.pyplot as plt

from utils import setBoundaries, makeSamples, PhysicalLoss
from glob import glob

from networks import UNet, GrowingUNet, VariableLossUNet

In [21]:
import numpy as np
from scipy.ndimage import zoom
from scipy.signal import convolve2d

def solve(initial, fixed, tol=1e-15, warm_start=None, max_iter=None):
    """Simulate heat diffusion to identify the steady state.

    Runs Jacobi sweeps of the discrete Laplace equation: each cell is replaced
    by the average of its four neighbours (edges reflected via
    ``boundary='symm'``), after which the cells selected by ``fixed`` are reset
    to their values from ``initial``.

    Parameters
    ----------
    initial: float array
      the initial temperature at every point in the grid; only the entries
      selected by ``fixed`` are read
    fixed: bool array
      elements that are set to True will be kept fixed and not allowed to change
    tol: float
      iteration continues until no element changes by more than this
    warm_start: float array, optional
      starting guess with the same shape as ``initial``. It is copied and never
      modified. When omitted, every cell starts at the mean of the fixed values
      (0.0 if no cell is fixed).
    max_iter: int, optional
      stop after this many sweeps even if ``tol`` has not been reached; a
      RuntimeWarning is issued in that case. ``None`` (default) means no limit.

    Returns
    -------
    (array, iterations)
      the grid before the final sweep (which changed no element by more than
      ``tol``), and the number of sweeps performed including that final one.
      If ``max_iter`` stops the loop, the latest grid is returned instead.

    Raises
    ------
    ValueError
      if ``warm_start`` does not have the same shape as ``initial``
    """
    import warnings

    # 4-neighbour averaging stencil.
    mask = np.array([[0, 0.25, 0], [0.25, 0.0, 0.25], [0, 0.25, 0]])
    fixed_values = initial[fixed]

    if warm_start is not None:
        if np.shape(warm_start) != np.shape(initial):
            raise ValueError("warm_start has shape %s, expected %s"
                             % (np.shape(warm_start), np.shape(initial)))
        # Copy so the caller's array (e.g. a network prediction it still needs)
        # is not overwritten with the fixed values below.
        array = np.array(warm_start, dtype=float)
    else:
        # np.mean of an empty selection is NaN, and a NaN grid never satisfies
        # `change <= tol`, which would make the loop below spin forever.
        start_value = np.mean(fixed_values) if fixed_values.size else 0.0
        array = np.full(initial.shape, start_value, dtype=float)

    array[fixed] = fixed_values

    # Iterate until convergence is reached (or the optional cap is hit).
    iterations = 0
    while True:
        iterations += 1
        new_array = convolve2d(array, mask, mode='same', boundary='symm')
        new_array[fixed] = fixed_values
        change = np.max(np.abs(array - new_array))
        if change <= tol:
            return array, iterations
        if max_iter is not None and iterations >= max_iter:
            warnings.warn("solve did not converge within %d iterations "
                          "(last change %g > tol %g)" % (iterations, change, tol),
                          RuntimeWarning)
            return new_array, iterations
        array = new_array

In [3]:
# Experiment configuration.
batch_size = 1
cuda = True
epoch_size = 300
epochs = 256
experiment = 'run256_grow_long'
growing = True
image_size = 256
learning_rate = 0.0002
manualSeed = None  # set to an int to make the random test cases reproducible
start_size = 4

# Apply manualSeed. It was previously declared but never used, so the random
# boundary conditions drawn in later cells could not be reproduced.
if manualSeed is not None:
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)

# Set up CUDA
if cuda and torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
else:
    if torch.cuda.is_available():
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    dtype = torch.FloatTensor


In [23]:
# # Make output directory
# os.makedirs(experiment, exist_ok=True)


# if not growing:
#     start_size = image_size

# net = GrowingUNet(dtype, image_size=image_size, start_size=start_size).type(dtype)
# print(net)

# physical_loss = PhysicalLoss(dtype)
# optimizer = optim.Adam(net.parameters(), lr=learning_rate)

# ## Outer training loop
# size = start_size
# epoch = 0
# num_stages = int(np.log2(image_size) - np.log2(start_size)) + 1 if growing else 0
# if num_stages >= 1:
#     epochs = int(epochs*2**(-1*(num_stages-stage)))
# stage = 0

# while True:
#     fixed_sample_0 = torch.zeros(1,1,size,size)
#     fixed_sample_0[:,:,:,0] = 100
#     fixed_sample_0[:,:,0,:] = 0
#     fixed_sample_0[:,:,:,-1] = 100
#     fixed_sample_0[:,:,-1,:] = 0
#     fixed_sample_0 = Variable(fixed_sample_0).cuda()

#     fixed_sample_1 = torch.zeros(1,1,size,size)
#     fixed_sample_1[:,:,:,0] = 100
#     fixed_sample_1[:,:,0,:] = 100
#     fixed_sample_1[:,:,:,-1] = 100
#     fixed_sample_1[:,:,-1,:] = 100
#     fixed_sample_1 = Variable(fixed_sample_1).cuda()

#     boundary = np.zeros((size, size), dtype=np.bool)
#     boundary[0,:] = True
#     boundary[-1,:] = True
#     boundary[:,0] = True
#     boundary[:,-1] = True

#     fixed_solution_0 = solve(fixed_sample_0.cpu().data.numpy()[0,0,:,:], boundary, tol=1e-4)
#     fixed_solution_1 = solve(fixed_sample_1.cpu().data.numpy()[0,0,:,:], boundary, tol=1e-4)

#     ## Inner training loop
#     data = torch.zeros(batch_size,1,size,size)
#     #data = torch.zeros(batch_size,1,image_size,image_size)
#     for _epoch in range(epochs):
#         mean_loss = 0
#         for sample in range(epoch_size):
#             data[:,:,:,0] = np.random.uniform(100)
#             data[:,:,0,:] = np.random.uniform(100)
#             data[:,:,:,-1] = np.random.uniform(100)
#             data[:,:,-1,:] = np.random.uniform(100)
#             img = Variable(data).type(dtype)
#             output = net(img)
#             loss = physical_loss(output)
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#             mean_loss += loss.data[0]
#         mean_loss /= epoch_size
#         print('epoch [{}/{}], size {}, loss:{:.4f}'
#               .format(epoch+1, epochs, size, mean_loss))
#         epoch += 1

#         # Plot real samples
#         plt.figure(figsize=(20, 15))
#         f_0 = net(fixed_sample_0)
#         f_1 = net(fixed_sample_1)
#         plt.subplot(2,2,1)
#         plt.imshow(f_0.cpu().data.numpy()[0,0,:,:], vmin=0, vmax=100, cmap=plt.cm.jet)
#         plt.axis('equal')
#         plt.subplot(2,2,2)
#         plt.imshow(f_1.cpu().data.numpy()[0,0,:,:], vmin=0, vmax=100, cmap=plt.cm.jet)
#         plt.axis('equal')
#         plt.subplot(2,2,3)
#         plt.imshow(fixed_solution_0, vmin=0, vmax=100, cmap=plt.cm.jet)
#         plt.axis('equal')
#         plt.subplot(2,2,4)
#         plt.imshow(fixed_solution_1, vmin=0, vmax=100, cmap=plt.cm.jet)
#         plt.axis('equal')
#         plt.savefig('%s/f_1_epoch%d.png' % (experiment, epoch))
#         plt.close()

#         # checkpoint networks
#         if epoch % 50 == 0:
#             torch.save(net.state_dict(), '%s/net_epoch_%d.pth' % (experiment, epoch))

#         if epoch >= epochs:
#             torch.save(net.state_dict(), '%s/net_epoch_%d.pth' % (experiment, epoch))
#             exit()

#     if size < image_size:
#         size *= 2
#         net.setSize(size)
#         stage += 1

GrowingUNet(
  (encoding_layers): ModuleList(
    (0): Conv2d (1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): Conv2d (64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (2): Conv2d (128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): Conv2d (256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): Conv2d (512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): Conv2d (512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): Conv2d (512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (encoding_bns): ModuleList(
    (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (5): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=T

KeyboardInterrupt: 

In [11]:
num_test = 10


def _checkpoint_epoch(path):
    """Return the trailing integer of a checkpoint name such as ``net_epoch_4095.pth``."""
    return int(os.path.splitext(os.path.basename(path))[0].split("_")[-1])


files = glob(os.path.join(experiment, "*.pth"))
if not files:
    raise FileNotFoundError("no .pth checkpoints found in %r" % experiment)
# Pick the newest checkpoint directly instead of re-globbing on "*<epoch>.pth".
file = max(files, key=_checkpoint_epoch)
maximum = _checkpoint_epoch(file)
print(file)

if not growing:
    net = UNet(dtype, image_size=image_size).type(dtype)
else:
#     net = VariableLossUNet(dtype, image_size=image_size).type(dtype)
    # Use the configured start_size rather than a hard-coded 4.
    net = GrowingUNet(dtype, image_size=image_size, start_size=start_size).type(dtype)
    # Presumably grows the network to full resolution so its layers match the
    # final checkpoint; confirm against GrowingUNet.setSize.
    net.setSize(image_size)

# Deserialize onto the CPU so a checkpoint written from a GPU can also be read on
# a CPU-only machine; load_state_dict copies the values into the model's own
# parameters on whatever device they live. torch.load unpickles, so only load
# checkpoints you produced yourself.
state_dict = torch.load(file, map_location=lambda storage, loc: storage)
net.load_state_dict(state_dict)
print(net)

run256_grow_long/net_epoch_4095.pth
GrowingUNet(
  (encoding_layers): ModuleList(
    (0): Conv2d (1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): Conv2d (64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (2): Conv2d (128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): Conv2d (256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): Conv2d (512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): Conv2d (512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): Conv2d (512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): Conv2d (512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (encoding_bns): ModuleList(
    (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (4): 

In [22]:
physical_loss = PhysicalLoss(dtype)

# Mask of the outer ring of the grid, which `solve` keeps fixed.
# `np.bool` was removed in NumPy 1.24; the builtin `bool` is the portable dtype.
boundary = np.zeros((image_size, image_size), dtype=bool)
boundary[0, :] = True
boundary[-1, :] = True
boundary[:, 0] = True
boundary[:, -1] = True

# Plot coordinates are the same for every test case, so build them once.
XX, YY = np.meshgrid(np.arange(0, image_size), np.arange(0, image_size))


def _plot_field(ax, field, title):
    """Draw one temperature field as filled contours on a fixed 0-100 colour scale."""
    # `levels=50` is what the former `colorinterpolation=50` keyword intended;
    # contourf has no such argument (newer matplotlib rejects it).
    ax.contourf(XX, YY, field, levels=50, vmin=0, vmax=100, cmap=plt.cm.jet)
    ax.set_title(title)
    ax.axis('equal')


data = torch.zeros(1, 1, image_size, image_size)
error = []
for i in range(num_test):
    # One random temperature per edge; later assignments win at the corners.
    # `np.random.uniform(100)` meant low=100, high=1.0; the intended range is [0, 100).
    data[:, :, :, 0] = np.random.uniform(0, 100)
    data[:, :, 0, :] = np.random.uniform(0, 100)
    data[:, :, :, -1] = np.random.uniform(0, 100)
    data[:, :, -1, :] = np.random.uniform(0, 100)

    # NOTE(review): `net` is never put in eval() mode here, so its BatchNorm
    # layers run in training mode on this single sample; confirm this is intended.
    img = Variable(data).type(dtype)
    output = net(img)
    loss = physical_loss(output)
    # `loss.data[0]` raises on the 0-dim losses of PyTorch >= 0.4; flattening first
    # works for both 0-dim and 1-element results.
    loss_value = float(loss.data.view(-1)[0])

    output = output.cpu().data.numpy()[0, 0, :, :]
    initial = data.cpu().numpy()[0, 0, :, :]

    # Pass a copy: `solve` writes the boundary values into its warm-start array,
    # which would otherwise alter the prediction used for the error and plot below.
    solution, iterations = solve(initial, boundary, tol=1e-3, warm_start=output.copy())
    print("Warm Start: ", iterations)
    solution, iterations = solve(initial, boundary, tol=1e-3)
    print("Finite Difference: ", iterations)

    error.append(np.mean(np.abs(output - solution)))
    print("%d Error: %.2f, Loss: %.2f" % (i, error[-1], loss_value))

    # Boundary condition, finite-difference reference, and network prediction.
    fig, axes = plt.subplots(3, 1, figsize=(15, 25))
    _plot_field(axes[0], initial, "Initial Condition")
    _plot_field(axes[1], solution, "Equilibrium Condition")
    _plot_field(axes[2], output, "Learned Output")
    fig.savefig('%s/test_%d.png' % (experiment, i))
    plt.close(fig)

error = np.array(error)
print("error: ", np.mean(error))

Warm Start:  21575
Finite Difference:  6640
0 Error: 9.81, Loss: 0.01
Warm Start:  13003
Finite Difference:  7665
1 Error: 6.57, Loss: 0.01
Warm Start:  24656
Finite Difference:  7457
2 Error: 17.36, Loss: 0.01
Warm Start:  16926
Finite Difference:  8054
3 Error: 6.64, Loss: 0.01


KeyboardInterrupt: 