In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Boundary relations handled by TMA:
#   x[0] = k1*x[1] + m1
#   x[N] = k2*x[N-1] + m2
# Interior equations, n = 1 .. N-1:
#   A[n]*x[n-1] - C[n]*x[n] + B[n]*x[n+1] = -F[n]

def TMA(A, B, C, F, k1=0, k2=0, m1=0, m2=0):
    """Solve a tridiagonal system with the Thomas algorithm (forward + back sweep).

    Parameters
    ----------
    A, B, C, F : 1-D arrays; N is taken from ``C.shape[0]``. Only entries
        1 .. N-1 are read (entry 0 is ignored).
    k1, m1 : left boundary relation ``x[0] = k1*x[1] + m1``.
    k2, m2 : right boundary relation ``x[N] = k2*x[N-1] + m2``.

    Returns
    -------
    numpy.ndarray of length N + 1 holding x[0] .. x[N].
    """
    size = C.shape[0]

    # Forward sweep coefficients: x[n] = alpha[n]*x[n+1] + beta[n].
    alpha = np.zeros(size + 1)
    beta = np.zeros(size + 1)
    alpha[0], beta[0] = k1, m1

    for n in range(1, size):
        denom = C[n] - A[n]*alpha[n-1]
        alpha[n] = B[n] / denom
        beta[n] = (F[n] + A[n]*beta[n-1]) / denom

    # Right boundary closes the recurrence, then back-substitute leftwards.
    solution = np.zeros(size + 1)
    solution[size] = (m2 + k2*beta[size-1]) / (1 - k2*alpha[size-1])
    for n in reversed(range(size)):
        solution[n] = alpha[n]*solution[n+1] + beta[n]

    return solution

In [30]:
# Largest node index per spatial axis (arrays hold N[d] + 1 nodes along axis d)
# and the number of time steps.
N = {'x': 20, 'y': 20, 'z': 20, 't': 40}

In [31]:
# Domain limits (start, end) for every axis, including time.
lim = {axis: (0, 1) for axis in 'xyzt'}


In [66]:
# Grid steps. The three spatial axes use different node layouts (see x(), y(), z()):
#  x: x(n) = n*h_x with h_x = L/(N - 1/2), so x(N) sits half a step past the right end;
#  y: y(n) = (n - 1/2)*h_y with h_y = L/(N - 1), so y(0) = -h_y/2 and y(N) = L + h_y/2;
#  z: z(n) = n*h_z with h_z = L/N, so z(0) and z(N) are the two end points.
h = {}
h['x'] = (lim['x'][1] - lim['x'][0]) / (N['x'] - 0.5)
h['y'] = (lim['y'][1] - lim['y'][0]) / (N['y'] - 1)
h['z'] = (lim['z'][1] - lim['z'][0]) / (N['z'])

# Time step: N['t'] steps span the time interval, so t(N['t']) reaches lim['t'][1].
# Uses the t limits and N['t'] (not the x ones).
h['t'] = (lim['t'][1] - lim['t'][0]) / N['t']

In [67]:
def x(n):
    """x-coordinate of node n: lim['x'][0] + n*h['x']."""
    origin = lim['x'][0]
    return origin + h['x']*n

def y(n):
    """y-coordinate of node n on a grid shifted by half a step (y(0) = lim['y'][0] - h['y']/2)."""
    step = h['y']
    return lim['y'][0] + step*n - step / 2

def z(n):
    """z-coordinate of node n: lim['z'][0] + n*h['z']."""
    first = lim['z'][0]
    step = h['z']
    return first + step*n

def t(n):
    """Time of level n: lim['t'][0] + n*h['t'] (n may be fractional, e.g. m + 0.5)."""
    start = lim['t'][0]
    return start + h['t']*n

In [68]:
# Solution on every time level, indexed u[m, i, j, k] (time, x, y, z).
u = np.zeros((N['t']+1, N['x']+1, N['y']+1, N['z']+1))
# Results of the three directional sweeps, one value per spatial node.
y1, y2, y3 = (np.zeros(u.shape[1:]) for _ in range(3))

In [69]:
# Work arrays for the tridiagonal coefficients passed to TMA (length N['x'];
# TMA never reads index 0).
A_x, B_x, C_x, F_x = (np.zeros(N['x']) for _ in range(4))

In [112]:
## Initial conditions: u(0, x, y, z) = sin(pi*x/2) * cos(2*pi*y) * sin(pi*z) on every node.
for i, j, k in np.ndindex(N['x']+1, N['y']+1, N['z']+1):
    u[0, i, j, k] = np.sin(np.pi*x(i)/2) * np.cos(2*np.pi*y(j)) * np.sin(np.pi*z(k))

In [113]:
## Time stepping: three implicit 1-D sweeps (x, y, z) solved with TMA, then an
## explicit corrector that advances u from level m to m + 1.
# NOTE(review): this cell calls L1, L2, L3 and f, which are defined in a later
# cell (In [114]); on a fresh kernel that cell must be run (or moved) first.

# The sweep coefficients do not depend on time or position, so build them once.
# TMA takes the system size from len(C), so each axis gets arrays of its own
# length N[d]. System along axis d: A*v[n-1] - C*v[n] + B*v[n+1] = -F[n];
# C uses 2/tau (not 2/tau**2) so it matches the 2*v/tau term in F.
sweep_A = {d: np.full(N[d], 1 / h[d]**2) for d in 'xyz'}
sweep_B = {d: np.full(N[d], 1 / h[d]**2) for d in 'xyz'}
sweep_C = {d: np.full(N[d], 2 / h[d]**2 + 2 / h['t']) for d in 'xyz'}
sweep_F = {d: np.zeros(N[d]) for d in 'xyz'}

for m in range(N['t']):

    # y1: implicit along x; the L2/L3 terms of u[m] are explicit.
    # Boundaries: y1 = 0 at i = 0 (k1 = m1 = 0), zero gradient at i = N['x'] (k2 = 1).
    # j and k start at 1: at j = 0 or k = 0, L2/L3 would read index -1, which
    # NumPy wraps to the opposite face. Those y1 rows are not read later.
    F = sweep_F['x']
    for j in range(1, N['y']):
        for k in range(1, N['z']):
            for i in range(1, N['x']):   # TMA ignores F[0]
                F[i] = 2 * u[m,i,j,k] / h['t'] + L2(u[m],i,j,k) + L3(u[m],i,j,k)
            y1[:,j,k] = TMA(sweep_A['x'], sweep_B['x'], sweep_C['x'], F, k2=1)

    # y2: implicit along y, starting from y1 and removing the explicit L2 u[m] term.
    # Zero gradient at both y ends (k1 = k2 = 1).
    F = sweep_F['y']
    for i in range(1, N['x']+1):
        for k in range(1, N['z']+1):
            for j in range(1, N['y']):
                F[j] = 2 * y1[i,j,k] / h['t'] - L2(u[m],i,j,k)
            y2[i,:,k] = TMA(sweep_A['y'], sweep_B['y'], sweep_C['y'], F, k1=1, k2=1)

    # y3: implicit along z (step h['z']), starting from y2 and removing the
    # explicit L3 u[m] term. Zero values at both z ends (default k1 = k2 = m1 = m2 = 0).
    F = sweep_F['z']
    for i in range(0, N['x']+1):
        for j in range(0, N['y']+1):
            for k in range(1, N['z']):
                F[k] = 2 * y2[i,j,k] / h['t'] - L3(u[m],i,j,k)
            y3[i,j,:] = TMA(sweep_A['z'], sweep_B['z'], sweep_C['z'], F)

    # Corrector on interior nodes, with the source f taken at the half step m + 0.5.
    for i in range(1, N['x']):
        for j in range(1, N['y']):
            for k in range(1, N['z']):
                u[m+1,i,j,k] = (u[m,i,j,k]
                                + h['t'] * (L1(y1,i,j,k) + L2(y2,i,j,k) + L3(y3,i,j,k))
                                + h['t']*f(m+0.5, i, j, k))

    # Boundary conditions. The interior loop never touches the x = 0 face or
    # the two z faces, so they keep the zeros u was allocated with.
    u[m+1, N['x'], :, :] = u[m+1, N['x']-1, :, :]   # zero gradient, right x face
    u[m+1, :, 0, :] = u[m+1, :, 1, :]               # zero gradient, left y face
    u[m+1, :, N['y'], :] = u[m+1, :, N['y']-1, :]   # zero gradient, right y face
    

In [114]:
def L(u,i,j,k):
    """Full 7-point discrete Laplacian of the 3-D array `u` at node (i, j, k).

    Sum of the directional second differences L1 + L2 + L3. Delegating to them
    keeps a single definition of each stencil instead of repeating the formulas.
    Each index must lie in 1 .. size-2 along its axis: index 0 makes the stencil
    read index -1, which NumPy wraps to the opposite face.
    """
    return L1(u,i,j,k) + L2(u,i,j,k) + L3(u,i,j,k)

def L1(u,i,j,k):
    """Second central difference of `u` along axis 0 (x) at (i, j, k), divided by h['x']**2.

    Needs 1 <= i <= u.shape[0] - 2; i = 0 reads u[-1, j, k], which wraps around.
    """
    before, centre, after = u[i-1,j,k], u[i,j,k], u[i+1,j,k]
    return (before - 2*centre + after) / h['x']**2

def L2(u,i,j,k):
    """Second central difference of `u` along axis 1 (y) at (i, j, k), divided by h['y']**2.

    Needs 1 <= j <= u.shape[1] - 2; j = 0 reads u[i, -1, k], which wraps around.
    """
    before, centre, after = u[i,j-1,k], u[i,j,k], u[i,j+1,k]
    return (before - 2*centre + after) / h['y']**2

def L3(u,i,j,k):
    """Second central difference of `u` along axis 2 (z) at (i, j, k), divided by h['z']**2.

    Needs 1 <= k <= u.shape[2] - 2; k = 0 reads u[i, j, -1], which wraps around.
    """
    before, centre, after = u[i,j,k-1], u[i,j,k], u[i,j,k+1]
    return (before - 2*centre + after) / h['z']**2

def f(m,i,j,k):
        return np.exp(t(m)) * np.sin(3*np.pi*x(i)/2) * np.sin(np.pi*z(k))


In [116]:
# Shape of the third-sweep buffer; expected (N['x']+1, N['y']+1, N['z']+1).
np.shape(y3)


(21, 21, 21)