In [6]:
"""
@author: Zongyi Li and Daniel Zhengyu Huang
"""

import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from timeit import default_timer
from catheter import *
from utilities3 import *
from Adam import Adam

# Seed every random number generator in use (PyTorch CPU, NumPy, PyTorch CUDA)
# with the same fixed value so that repeated runs start from the same state.
torch.manual_seed(0)
np.random.seed(0)
torch.cuda.manual_seed(0)
# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device is ", device)

device is  cuda


In [3]:
def Lx2length(L_x, L_p, x1, x2, x3, h):
    """Map a horizontal extent inside one period to arc length along the profile.

    The profile is treated as a polyline made of a flat piece of length ``-x3``,
    a piece of length ``sqrt((x2-x3)**2 + h**2)``, a piece of length
    ``sqrt((x1-x2)**2 + h**2)`` and a final flat piece. ``L_x`` is compared
    against ``-x3``, ``-x2`` and ``-x1`` to decide which piece it ends in, and
    the partially covered piece is interpolated linearly in the horizontal
    coordinate.

    Parameters
    ----------
    L_x : horizontal extent measured from the start of the period.
    L_p : period length; accepted for signature symmetry, not used here.
    x1, x2, x3 : horizontal break points (compared through their negatives).
    h : vertical offset of the middle vertex; passed through ``torch.sqrt``,
        so the squared-sum expressions must evaluate to tensors.

    Returns
    -------
    The arc length corresponding to ``L_x`` (``L_x`` itself on the first piece).
    """
    first_flat = -x3
    second_piece = torch.sqrt((x2-x3)**2 + h**2)
    third_piece = torch.sqrt((x1-x2)**2 + h**2)

    # Still on the first flat piece: arc length equals horizontal extent.
    if L_x < -x3:
        return L_x

    # Part-way along the piece between x3 and x2.
    if L_x < -x2:
        return first_flat + second_piece*(L_x + x3)/(x3-x2)

    # Part-way along the piece between x2 and x1.
    if L_x < -x1:
        return first_flat + second_piece + third_piece*(L_x + x2)/(x2-x1)

    # On the trailing flat piece past x1.
    return first_flat + second_piece + third_piece + L_x+x1

def d2xy(d, L_p, x1, x2, x3, h):
    """Convert arc-length positions within one period into (x, y) coordinates.

    One period is the polyline (0, 0) -> (x3, 0) -> (x2, h) -> (x1, 0) ->
    (-L_p, 0). Each entry of ``d`` is assigned to the segment whose arc-length
    interval contains it and is placed by linear interpolation along that
    segment. Entries below the first threshold use the first segment; entries
    at or beyond the third threshold use the last segment.

    Parameters
    ----------
    d : tensor of arc-length positions measured from the start of the period.
    L_p : period length.
    x1, x2, x3 : horizontal coordinates of the polyline vertices.
    h : vertical coordinate of the middle vertex.

    Returns
    -------
    (xx, yy) : two tensors created with ``torch.zeros(d.shape)`` holding the
        x and y coordinates for every entry of ``d``.
    """
    # Start point, direction vector and length of each of the four segments.
    origins = (
        torch.tensor([0.0,0.0]),
        torch.tensor([x3,0.0]),
        torch.tensor([x2, h]),
        torch.tensor([x1,0.0]),
    )
    directions = (
        torch.tensor([x3-0,0.0]),
        torch.tensor([x2-x3,h]),
        torch.tensor([x1-x2,-h]),
        torch.tensor([-L_p-x1,0.0]),
    )
    lengths = (
        -x3,
        torch.sqrt((x2-x3)**2 + h**2),
        torch.sqrt((x1-x2)**2 + h**2),
        L_p+x1,
    )

    # Cumulative arc length at the end of segments 0, 1 and 2.
    end0 = lengths[0]
    end1 = end0 + lengths[1]
    end2 = end1 + lengths[2]
    selectors = (
        d < end0,
        torch.logical_and(d < end1, d >= end0),
        torch.logical_and(d < end2, d >= end1),
        d >= end2,
    )

    xx, yy = torch.zeros(d.shape), torch.zeros(d.shape)

    # `along` is the arc length measured from the start of the current segment;
    # the previous segment's length is peeled off one segment at a time.
    along = d
    for seg in range(4):
        if seg > 0:
            along = along - lengths[seg - 1]
        pick = selectors[seg]
        local = along[pick]
        xx[pick] = local*directions[seg][0]/lengths[seg] + origins[seg][0]
        yy[pick] = local*directions[seg][1]/lengths[seg] + origins[seg][1]

    return xx, yy

def catheter_mesh_1d_total_length(L_x, L_p, x2, x3, h, N_s):
    """Sample ``N_s`` points, equally spaced in arc length, along a periodic profile.

    The profile repeats with period ``L_p`` and covers a horizontal extent
    ``L_x``. Within a period it is the polyline handled by ``d2xy``, with
    ``x1`` fixed at ``-0.5*L_p``.

    Parameters
    ----------
    L_x : total horizontal extent of the profile.
    L_p : period length; ``L_x / L_p`` is passed to ``torch.floor``, so the
        quotient must be a tensor.
    x2, x3 : horizontal coordinates of the polyline vertices.
    h : vertical coordinate of the middle vertex.
    N_s : number of sample points.

    Returns
    -------
    X_Y : float tensor of shape (1, N_s, 2) on the module-level ``device``;
        ``X_Y[0, :, 0]`` holds x and ``X_Y[0, :, 1]`` holds y.
    xx, yy : the same coordinates as separate tensors, as produced by ``d2xy``
        (with the per-period x shift applied to ``xx``).
    """
    x1 = -0.5*L_p

    # Arc length of one complete period: two flat pieces plus two slanted ones.
    arc_per_period = ((x1 + L_p) + (0 - x3) + torch.sqrt((x2 - x1)**2 + h**2) + torch.sqrt((x3 - x2)**2 + h**2))

    # Split L_x into whole periods and the horizontal remainder of the last one.
    whole_periods = torch.floor(L_x / L_p)
    remainder_x = L_x - whole_periods*L_p
    total_arc = arc_per_period*whole_periods + Lx2length(remainder_x, L_p, x1, x2, x3, h)

    # Equally spaced arc-length samples, starting from 0.
    arc_samples = torch.linspace(0, 1, N_s) * total_arc

    # Index of the period each sample falls in; detached so autograd treats it
    # as a constant.
    period_idx = torch.floor(arc_samples / arc_per_period).detach()
    arc_in_period = arc_samples - period_idx * arc_per_period

    xx, yy = d2xy(arc_in_period, L_p, x1, x2, x3, h)

    # Shift each sample horizontally by the periods that precede it.
    xx = xx - period_idx*L_p

    X_Y = torch.zeros((1, N_s, 2), dtype=torch.float).to(device)
    X_Y[0, :, 0] = xx
    X_Y[0, :, 1] = yy
    return X_Y, xx, yy




In [13]:
################################################################
# inverse optimization for 1d
################################################################
# Brute-force grid search over the design parameters (L_p, x3, x2, h): every
# candidate profile is meshed, pushed through the trained model, and scored;
# the candidate with the smallest loss is kept and plotted.

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
model = torch.load("catheter_plain_length_model_1d1000", map_location=device)
# NOTE(review): the model is used in whatever train/eval mode it was saved in;
# if it contains dropout or batch-norm layers, call model.eval() -- confirm.
print(count_params(model))

# Optimizer-style settings; not referenced by the grid search in this cell.
learning_rate = 0.001
epochs = 5001
step_size = 1000
gamma = 0.1

L_x = 500
N_s = 2001
# Weights -L_x ... 0 applied to the model output when computing the loss.
xx_mask = (torch.linspace(1.0, 0, N_s) * (-L_x)).to(device)

# constraints
#               60 < L_p < 250
#               x1 = -0.5L_p
#               -L_p/4 < x2 - x3 < L_p/4
#               15 < x3 - x1 < L_p/4
#               20 < h < 30
# NOTE(review): the grid below lets x3 - x1 run up to L_p/2 and sweeps x2 over
# [-L_p, 0] without enforcing the x2 - x3 bound -- confirm which is intended.

loss_min = np.inf
L_p_min, x2_min, x3_min, h_min = np.nan, np.nan, np.nan, np.nan
mesh_min = np.nan
density_min = np.nan
# NOTE(review): 190 points over [61, 249] gives a spacing of 188/189 (~0.995);
# 189 points would give unit spacing -- confirm which is intended.
nLp = 190

# Pure inference: disable autograd so no graph is built or retained.
with torch.no_grad():
    for L_p in np.linspace(61, 249, nLp):
        print("L_p = ", L_p)
        x1 = -L_p/2
        # Half-unit grid spacing for x3 - x1 and for x2.
        n_dx3 = int(np.round(2*(L_p/2 - 15))) + 1
        n_x2 = int(np.round(2*L_p)) + 1
        for dx3_x1 in np.linspace(15, L_p/2, n_dx3):
            x3 = x1 + dx3_x1
            for x2 in np.linspace(-L_p, 0, n_x2):
                for h in np.linspace(20, 30, 21):
                    # Tensor copies get their own names so the float-valued
                    # loop variables are never overwritten; the outer grids
                    # are always built from the original floats.
                    L_p_t, x2_t, x3_t, h_t = (
                        torch.tensor(value, dtype=torch.float)
                        for value in (L_p, x2, x3, h)
                    )

                    x, XC, YC = catheter_mesh_1d_total_length(L_x, L_p_t, x2_t, x3_t, h_t, N_s)

                    out = model(x).squeeze()
                    loss = -torch.sum(torch.matmul(out, xx_mask))* L_x/N_s
                    if loss < loss_min:
                        L_p_min, x2_min, x3_min, h_min = L_p_t, x2_t, x3_t, h_t
                        mesh_min = np.copy(x.detach().cpu().numpy())
                        density_min = np.copy(out.detach().cpu().numpy())
                        loss_min = loss
                        print(L_p_min, x2_min, x3_min, h_min, loss_min)

plt.figure(figsize=(5,4))
plt.title("Bacterial population")
plt.plot(xx_mask.detach().cpu().numpy(), density_min, "-o", fillstyle='none', markevery=len(xx_mask)//10, label="Prediction")
plt.plot(mesh_min[0, :, 0], mesh_min[0, :, 1], color="r", label="Design")

plt.legend()
plt.show()



1336001
L_p =  61.0
tensor(61.) tensor(-61.) tensor(-15.5000) tensor(20.) tensor(1383478., device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-61.) tensor(-15.5000) tensor(20.5000) tensor(1361050., device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-61.) tensor(-15.5000) tensor(21.) tensor(1337339.2500, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-61.) tensor(-15.5000) tensor(21.5000) tensor(1312431.2500, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-61.) tensor(-15.5000) tensor(22.) tensor(1285901.6250, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-61.) tensor(-15.5000) tensor(22.5000) tensor(1258725., device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-61.) tensor(-15.5000) tensor(23.) tensor(1230547.8750, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-61.) tensor(-15.5000) tensor(23.5000) tensor(1201949.5000, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-61.) tensor(-15.5000) tensor



tensor(61.) tensor(-60.5000) tensor(-15.5000) tensor(30.) tensor(854852.6250, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-60.) tensor(-15.5000) tensor(30.) tensor(848431.2500, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-59.5000) tensor(-15.5000) tensor(30.) tensor(842640.3750, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-59.) tensor(-15.5000) tensor(30.) tensor(837376.1250, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-58.5000) tensor(-15.5000) tensor(30.) tensor(832487.1250, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-58.) tensor(-15.5000) tensor(30.) tensor(827586.3750, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-57.5000) tensor(-15.5000) tensor(30.) tensor(822440.3125, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-57.) tensor(-15.5000) tensor(30.) tensor(817428.1875, device='cuda:0', grad_fn=<DivBackward0>)
tensor(61.) tensor(-56.5000) tensor(-15.5000) tensor(30.) tensor

KeyboardInterrupt: 