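# Solving Burgers' Equation with Transfer Learning

This notebook proceeds in two stages:

- **Low-fidelity model.** A physics-informed network is trained on the linear diffusion equation $u_t = \nu_{low}\,u_{xx}$ with $\nu_{low} = 0.05$.
- **High-fidelity model.** The trained weights initialise a second network of identical architecture for the viscous Burgers equation $u_t + u\,u_x = \nu_{high}\,u_{xx}$ with $\nu_{high} = 0.001$. This network is then fine-tuned with the layers at indices 1–7 frozen.

Both models take the parameter $d$ as an extra input; $d$ enters the initial and boundary conditions.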

## Import Relevant Packages

In [1]:
import numpy as np
import sciann as sn
from tensorflow.keras import callbacks

---------------------- SCIANN 0.6.1.1 ---------------------- 
For details, check out our review paper and the documentation at: 
 +  "https://arxiv.org/abs/2005.08803", 
 +  "https://www.sciann.com". 

 Need support or would like to contribute, please join sciann`s slack group: 
 +  "https://join.slack.com/t/sciann/shared_invite/zt-ne1f5jlx-k_dY8RGo3ZreDXwz0f~CeA" 
 


In [2]:
# Independent variables: space x, time t, and the condition parameter d.
x, t, d = (sn.Variable(name) for name in ('x', 't', 'd'))

# Viscosities: the low-fidelity (diffusion-only) model uses a much larger
# value than the high-fidelity Burgers model.
v_low, v_high = 0.05, 0.001

# Low-fidelity network: 8 hidden tanh layers of 20 neurons, inputs (d, t, x).
u_low = sn.Functional('u_low', [d, t, x], [20] * 8, 'tanh')

In [3]:
from sciann.utils.math import diff, sign, sin

# Low-fidelity model: linear diffusion u_t = v_low * u_xx (no advection term).
# TOL is the half-width of the band in which the initial/boundary conditions
# are enforced, so the masks do not depend on samples landing *exactly* on
# t = 0 or x = +/-1 (sign(0) == 0 made the old masks rely on exact equality).
# Each mask evaluates to 2 inside its band and 0 elsewhere.
TOL = 0.001

# PDE residual.
L1 = diff(u_low, t) - v_low * diff(u_low, x, order=2)
# Initial condition for t < TOL: u = 1 - (1 + x) * (1 + d/2).
L2 = (1 - sign(t - TOL)) * (u_low - 1 + (1+x)*(1 + d/2))
# Left boundary for x < -1 + TOL: u = 0 (residual scaled by (1 + d)).
L3 = (1 - sign(x - (-1 + TOL))) * u_low * (1 + d)
# Right boundary for x > 1 - TOL: u = 0.
L4 = (1 + sign(x - (1 - TOL))) * u_low

In [4]:
# Low-fidelity model: drives the PDE residual L1 and the condition residuals L2-L4 to zero.
m = sn.SciModel(
    [x, t, d],
    [L1, L2, L3, L4],
)

In [5]:
# Collocation grid: 10 points in x, 10 in t, 5 values of the parameter d.
x_grid = np.linspace(-1, 1, 10)
t_grid = np.linspace(0, 12, 10)
d_grid = np.linspace(0, 0.1, 5)
# Default 'xy' indexing: arrays have shape (len(t_grid), len(x_grid), len(d_grid)).
x_data, t_data, d_data = np.meshgrid(x_grid, t_grid, d_grid)

In [7]:
# Train the low-fidelity model; all four targets are driven to zero.
# Keras' EarlyStopping monitors `val_loss` by default, but no validation data
# is passed here, so a default EarlyStopping never fires (Keras only logs a
# warning). Monitor the training loss instead; a larger patience avoids
# stopping on ordinary mini-batch noise.
h = m.train(
    [x_data, t_data, d_data],
    4*['zero'],
    learning_rate=0.002,
    epochs=5000,
    callbacks=[callbacks.EarlyStopping(monitor='loss', patience=100)],
    verbose=0)


Total samples: 500 
Batch size: 64 
Total batches: 8 



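## Transfer Learning

The high-fidelity network uses the same architecture as the low-fidelity one (8 hidden tanh layers of 20 neurons). Its weights can therefore be copied over directly before fine-tuning on the nonlinear Burgers residual.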

In [8]:
# High-fidelity model: viscous Burgers u_t + u * u_x = v_high * u_xx.
# The architecture must match u_low so its weights can be copied over below.
u_high = sn.Functional('u_high', [d,t,x], 8*[20], 'tanh')

# PDE residual.
L5 = diff(u_high, t) + u_high * diff(u_high, x) - v_high * diff(u_high, x, order=2)
# Same initial/boundary targets as the low-fidelity model. They are enforced
# inside a TOL-wide band instead of relying on exact equality with t = 0 or x = +/-1.
# Initial condition for t < TOL: u = 1 - (1 + x) * (1 + d/2).
L6 = (1 - sign(t - TOL)) * (u_high - 1 + (1+x)*(1 + d/2))
# Left boundary for x < -1 + TOL: u = 0 (residual scaled by (1 + d)).
L7 = (1 - sign(x - (-1 + TOL))) * u_high * (1 + d)
# Right boundary for x > 1 - TOL: u = 0.
L8 = (1 + sign(x - (1 - TOL))) * u_high

In [9]:
# Warm-start the high-fidelity network from the trained low-fidelity network.
# Both were built with identical architectures, so the weight lists line up.
u_high.set_weights(u_low.get_weights())

# Freeze the layers at indices 1..7; the remaining layers stay trainable.
# NOTE(review): which physical layers these indices map to depends on sciann's
# Functional layer ordering. Confirm that index 0 is the first hidden layer
# and index 8 the output layer.
frozen_layers = list(range(1, 8))
u_high.set_trainable(False, frozen_layers)



In [10]:
# High-fidelity model: drives the Burgers residual L5 and the condition residuals L6-L8 to zero.
m2 = sn.SciModel(
    [x, t, d],
    [L5, L6, L7, L8],
)

In [11]:
# Fine-tune the high-fidelity model on the same grid; all four targets are
# driven to zero.
# Keras' EarlyStopping monitors `val_loss` by default, but no validation data
# is passed here, so a default EarlyStopping never fires (Keras only logs a
# warning). Monitor the training loss instead; a larger patience avoids
# stopping on ordinary mini-batch noise.
h2 = m2.train(
    [x_data, t_data, d_data],
    4*['zero'],
    learning_rate=0.002,
    epochs=5000,
    callbacks=[callbacks.EarlyStopping(monitor='loss', patience=100)],
    verbose=0)


Total samples: 500 
Batch size: 64 
Total batches: 8 



In [12]:
# Evaluate the fine-tuned high-fidelity network on the training grid.
# The inputs follow the [x, t, d] order used to build m2.
grid_inputs = [x_data, t_data, d_data]
u_pred = u_high.eval(m2, grid_inputs)