In [1]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


#!/usr/bin/env python
# coding: utf-8

# """
# @author: Zongyi Li
# This file implements the Fourier Neural Operator for 3D problems such as the Navier-Stokes equation discussed in Section 5.3 of the [paper](https://arxiv.org/pdf/2010.08895.pdf),
# which treats the 2D spatial + 1D temporal equation directly as a 3D problem.
# FNO: https://github.com/zongyi-li/fourier_neural_operator
# UNO: https://github.com/ashiq24/UNO
# """


from scipy import signal
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt


import operator
from functools import reduce
from functools import partial

from timeit import default_timer
import pickle
torch.manual_seed(0)
np.random.seed(0)
device = torch.device('cuda:0')


class GaussianNormalizer(object):
    """Normalize tensors with one scalar mean/std computed over all of `x`.

    Args:
        x: tensor whose global mean and standard deviation define the
            normalization.
        eps: small constant added to the standard deviation to avoid
            division by zero.
    """

    def __init__(self, x, eps=0.00001):
        super().__init__()

        self.mean = x.mean()
        self.std = x.std()
        self.eps = eps

    def encode(self, x):
        """Return `x` shifted by the stored mean and scaled by (std + eps)."""
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x, sample_idx=None):
        """Invert `encode`. `sample_idx` is accepted but not used."""
        return x * (self.std + self.eps) + self.mean

    def cuda(self):
        """Move the stored statistics to the module-level `device`."""
        self.mean, self.std = self.mean.to(device), self.std.to(device)

    def cpu(self):
        """Move the stored statistics back to the CPU."""
        self.mean, self.std = self.mean.cpu(), self.std.cpu()

# normalization, pointwise gaussian


class UnitGaussianNormalizer(object):
    """Pointwise Gaussian normalizer: statistics are taken over dim 0 only.

    Args:
        x: training tensor, e.g. of shape ntrain*n, ntrain*T*n or ntrain*n*T;
            mean and std are computed across the first (sample) dimension.
        eps: small constant added to the standard deviation.
    """

    def __init__(self, x, eps=0.00001):
        super().__init__()

        self.mean = x.mean(0)
        self.std = x.std(0)
        self.eps = eps

    def encode(self, x):
        """Return `x` standardized with the per-point mean and std."""
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x, sample_idx=None):
        """Invert `encode`, optionally on a subset of points.

        Args:
            x: encoded tensor (batch*n or T*batch*n).
            sample_idx: optional index used to select the statistics. It
                indexes the first axis of the statistics when the rank of
                `sample_idx[0]` equals the rank of the stored mean, and the
                second axis when the stored mean has higher rank.
        """
        if sample_idx is None:
            scale = self.std + self.eps  # n
            shift = self.mean
        else:
            stat_rank = len(self.mean.shape)
            idx_rank = len(sample_idx[0].shape)
            if stat_rank == idx_rank:
                scale = self.std[sample_idx] + self.eps  # batch*n
                shift = self.mean[sample_idx]
            if stat_rank > idx_rank:
                scale = self.std[:, sample_idx] + self.eps  # T*batch*n
                shift = self.mean[:, sample_idx]

        # x is in shape of batch*n or T*batch*n
        return (x * scale) + shift

    def cuda(self):
        """Move the stored statistics to the module-level `device`."""
        self.mean, self.std = self.mean.to(device), self.std.to(device)

    def cpu(self):
        """Move the stored statistics back to the CPU."""
        self.mean, self.std = self.mean.cpu(), self.std.cpu()

# normalization, pointwise gaussian


class InputNormalizer(object):
    """Gaussian normalizer built from per-sample statistics.

    For each sample the mean and standard deviation are reduced over
    dims 1, 2 and 3; those per-sample values are then averaged over dim 0.
    The stored std is therefore the mean of the per-sample stds, not the
    std of the pooled data.

    Args:
        x: tensor with at least four dimensions; dim 0 indexes samples.
        eps: small constant added to the standard deviation.
        nmax: if given, only the first `nmax` samples contribute to the
            statistics; `None` uses every sample.
    """

    def __init__(self, x, eps=0.00001, nmax=None):
        super(InputNormalizer, self).__init__()

        # The original tested the builtin `max` here instead of `nmax`, so
        # its fallback branch could never run.
        data = x if nmax is None else x[:nmax]
        self.mean = torch.mean(data, dim=(1, 2, 3)).mean(dim=0)
        self.std = torch.std(data, dim=(1, 2, 3)).mean(dim=0)
        self.eps = eps

    def encode(self, x):
        """Return `x` shifted by the stored mean and scaled by (std + eps)."""
        x = (x - self.mean) / (self.std + self.eps)
        return x

    def decode(self, x, sample_idx=None):
        """Invert `encode`, optionally selecting statistics with `sample_idx`.

        `sample_idx` indexes the first axis of the statistics when the rank
        of `sample_idx[0]` equals the rank of the stored mean, and the second
        axis when the stored mean has higher rank.
        """
        if sample_idx is None:
            std = self.std + self.eps  # n
            mean = self.mean
        else:
            if len(self.mean.shape) == len(sample_idx[0].shape):
                std = self.std[sample_idx] + self.eps  # batch*n
                mean = self.mean[sample_idx]
            if len(self.mean.shape) > len(sample_idx[0].shape):
                std = self.std[:, sample_idx] + self.eps  # T*batch*n
                mean = self.mean[:, sample_idx]

        # x is in shape of batch*n or T*batch*n
        x = (x * std) + mean
        return x

    def cuda(self):
        """Move the stored statistics to the module-level `device`."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)

    def cpu(self):
        """Move the stored statistics back to the CPU."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()


# loss function with rel/abs Lp loss
class LpLoss(object):
    """Absolute and relative Lp losses computed per example.

    Args:
        d: dimension used in the mesh-size scaling of the absolute loss.
        p: order of the norm.
        size_average: when reducing, average (True) or sum (False) over
            the examples.
        reduction: if False, return the per-example values unreduced.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        # Dimension and Lp-norm type are positive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def _reduce(self, values):
        """Apply the configured reduction to the per-example `values`."""
        if self.reduction:
            if self.size_average:
                return torch.mean(values)
            return torch.sum(values)
        return values

    def abs(self, x, y):
        """Return the mesh-scaled Lp norm of `x - y` for each example.

        The mesh size is taken as 1 / (x.size(1) - 1), i.e. a uniform mesh
        is assumed.
        """
        num_examples = x.size()[0]

        # Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0)

        # reshape (not view) so non-contiguous inputs are accepted, matching
        # the behaviour of `rel`.
        diff = x.reshape(num_examples, -1) - y.reshape(num_examples, -1)
        all_norms = (h**(self.d / self.p)) * torch.norm(diff, self.p, 1)

        return self._reduce(all_norms)

    def rel(self, x, y):
        """Return ||x - y||_p / ||y||_p for each example.

        No guard is applied against a zero norm of `y`.
        """
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(
            num_examples, -1) - y.reshape(num_examples, -1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1)

        return self._reduce(diff_norms / y_norms)

    def __call__(self, x, y):
        """Evaluate the relative loss."""
        return self.rel(x, y)

################################################################
# Code of UNO3D starts
# Pointwise and Fourier Layer
################################################################

class SpectralConv3d_UNO(nn.Module):
    """3D Fourier layer: FFT, learned linear transform on retained modes, inverse FFT.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        D1, D2, D3: default output dimensions (x, y, t).
        modes1, modes2, modes3: number of Fourier coefficients kept along
            each spectral dimension. If `modes1` is None, the defaults
            `D1`, `D2` and `D3 // 2 + 1` are used.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            D1,
            D2,
            D3,
            modes1=None,
            modes2=None,
            modes3=None):
        super(SpectralConv3d_UNO, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.d1 = D1
        self.d2 = D2
        self.d3 = D3
        if modes1 is not None:
            # Number of Fourier modes to multiply, at most floor(N/2) + 1
            self.modes1 = modes1
            self.modes2 = modes2
            self.modes3 = modes3
        else:
            # Take the maximum number of possible modes for the given
            # output dimension.
            self.modes1 = D1
            self.modes2 = D2
            self.modes3 = D3 // 2 + 1

        self.scale = (1 / (2 * in_channels))**(1.0 / 2.0)
        # One weight tensor per retained corner of the (x, y) spectrum.
        # Creation order is kept so seeded initialisation is unchanged.
        self.weights1 = self._init_weights()
        self.weights2 = self._init_weights()
        self.weights3 = self._init_weights()
        self.weights4 = self._init_weights()

    def _init_weights(self):
        """Create one scaled, uniformly initialised complex weight tensor."""
        return nn.Parameter(
            self.scale *
            torch.rand(
                self.in_channels,
                self.out_channels,
                self.modes1,
                self.modes2,
                self.modes3,
                dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul3d(self, input, weights):
        """Contract the input-channel axis of `input` against `weights`."""
        # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)

    def forward(self, x, D1=None, D2=None, D3=None):
        """Apply the spectral convolution.

        Args:
            x: input of shape (batch, in_channels, x, y, t).
            D1, D2, D3: optional output dimensions (x, y, t) for this call.
                Any value left as None falls back to the size given at
                construction; the stored defaults are not modified.

        Returns:
            Tensor of shape (batch, out_channels, D1, D2, D3).
        """
        # Resolve sizes into locals so a per-call override does not leak
        # into later calls through module state.
        d1 = self.d1 if D1 is None else D1
        d2 = self.d2 if D2 is None else D2
        d3 = self.d3 if D3 is None else D3

        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1], norm='forward')

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(
            batchsize,
            self.out_channels,
            d1,
            d2,
            d3 // 2 + 1,
            dtype=torch.cfloat,
            device=x.device)

        m1, m2, m3 = self.modes1, self.modes2, self.modes3
        out_ft[:, :, :m1, :m2, :m3] = self.compl_mul3d(
            x_ft[:, :, :m1, :m2, :m3], self.weights1)
        out_ft[:, :, -m1:, :m2, :m3] = self.compl_mul3d(
            x_ft[:, :, -m1:, :m2, :m3], self.weights2)
        out_ft[:, :, :m1, -m2:, :m3] = self.compl_mul3d(
            x_ft[:, :, :m1, -m2:, :m3], self.weights3)
        out_ft[:, :, -m1:, -m2:, :m3] = self.compl_mul3d(
            x_ft[:, :, -m1:, -m2:, :m3], self.weights4)

        # Return to physical space
        x = torch.fft.irfftn(out_ft, s=(d1, d2, d3), norm='forward')
        return x


class pointwise_op_3D(nn.Module):
    """1x1x1 convolution followed by spectral resampling to a target size.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        dim1, dim2, dim3: default output dimensions (x, y, t).
    """

    def __init__(self, in_channel, out_channel, dim1, dim2, dim3):
        super().__init__()
        self.conv = nn.Conv3d(int(in_channel), int(out_channel), 1)
        self.dim1, self.dim2, self.dim3 = int(dim1), int(dim2), int(dim3)

    def forward(self, x, dim1=None, dim2=None, dim3=None):
        """
        dim1,dim2,dim3 are the output dimensions (x,y,t); when dim1 is None
        the sizes given at construction are used.
        """
        if dim1 is None:
            dim1, dim2, dim3 = self.dim1, self.dim2, self.dim3
        target = (dim1, dim2, dim3)
        k1, k2, k3 = dim1 // 2, dim2 // 2, dim3 // 2

        spectrum = torch.fft.rfftn(self.conv(x), dim=[-3, -2, -1])

        # Keep only the four low-frequency corners of the (x, y) spectrum
        # and the first k3 temporal modes; everything else stays zero.
        kept = torch.zeros_like(spectrum)
        for rows in (slice(None, k1), slice(-k1, None)):
            for cols in (slice(None, k2), slice(-k2, None)):
                kept[:, :, rows, cols, :k3] = spectrum[:, :, rows, cols, :k3]

        resampled = torch.fft.irfftn(kept, s=target)

        # The inverse FFT already produced `target`-sized output; the
        # interpolate call is retained from the original implementation.
        return torch.nn.functional.interpolate(
            resampled, size=target, mode='trilinear', align_corners=True)

class OperatorBlock_3D(nn.Module,):
    """
    One operator layer: spectral convolution plus pointwise path, summed.

    To turn on normalization set Normalize = True
    To have a linear operator set Non_Lin = False
    """
    def __init__(self, in_channel, out_channel,res1, res2,res3,modes1,modes2,modes3, Normalize = True, Non_Lin = True):
        super().__init__()
        # Sub-module creation order is kept (conv, w, normalize_layer).
        self.conv = SpectralConv3d_UNO(
            in_channel, out_channel, res1, res2, res3, modes1, modes2, modes3)
        self.w = pointwise_op_3D(in_channel, out_channel, res1, res2, res3)
        self.normalize = Normalize
        self.non_lin = Non_Lin
        if Normalize:
            self.normalize_layer = torch.nn.InstanceNorm3d(
                out_channel, affine=True)

    def forward(self,x, res1 = None, res2 = None, res3 = None):
        """Apply the block; res1..res3 optionally override the output size."""
        spectral_part = self.conv(x, res1, res2, res3)
        pointwise_part = self.w(x, res1, res2, res3)
        result = spectral_part + pointwise_part
        if self.normalize:
            result = self.normalize_layer(result)
        return F.gelu(result) if self.non_lin else result

#######
## New 3D Neural operator
## Without any domain extension of the input function
## The following neural operator is designed to predict the next 40 time steps from the input (initial 10 time steps).
########
class Uno3D(nn.Module):
    """U-shaped 3D neural operator built from `OperatorBlock_3D` layers.

    The encoder/decoder resolutions and mode counts are hard-coded in the
    constructor; only the final block is resized at call time.

    Args:
        in_width: number of input features per point, including the three
            grid coordinates appended in `forward`.
        width: base channel width after lifting.
        pad: extra length requested from the last block along the final
            axis and cropped again before the projection layers.
        factor: multiplier applied to the channel widths of the blocks.
        pad_both: if True, crop `pad` symmetrically instead of from the end.
    """

    def __init__(self, in_width, width,pad = 0, factor = 1, pad_both = False):
        super().__init__()

        self.in_width = in_width  # input channel
        self.width = width

        self.padding = pad
        self.pad_both = pad_both

        def ch(mult):
            # Channel count of a block: mult * factor * width.
            return mult * factor * self.width

        # Lifting layers: in_width -> width // 2 -> width.
        self.fc_n1 = nn.Linear(self.in_width, self.width // 2)
        self.fc0 = nn.Linear(self.width // 2, self.width)

        # Encoder (creation order is kept so seeded init is unchanged).
        self.conv0 = OperatorBlock_3D(self.width, ch(2), 48, 48, 96, 24, 24, 24)
        self.conv1 = OperatorBlock_3D(ch(2), ch(4), 32, 32, 64, 16, 16, 16)
        self.conv2 = OperatorBlock_3D(ch(4), ch(8), 16, 16, 32, 8, 8, 8)
        self.conv3 = OperatorBlock_3D(ch(8), ch(16), 8, 8, 16, 4, 4, 4)

        # Bottleneck and decoder; conv7/conv8 take concatenated skip inputs.
        self.conv4 = OperatorBlock_3D(ch(16), ch(16), 8, 8, 16, 4, 4, 4)
        self.conv5 = OperatorBlock_3D(ch(16), ch(8), 16, 16, 32, 4, 4, 4)
        self.conv6 = OperatorBlock_3D(ch(8), ch(4), 32, 32, 64, 8, 8, 8)
        self.conv7 = OperatorBlock_3D(ch(8), ch(2), 48, 48, 96, 16, 16, 16)
        # Output size of conv8 is overridden in forward.
        self.conv8 = OperatorBlock_3D(ch(4), 2*self.width, 64, 64, 128, 24, 24, 32)

        # Projection: (2*width from conv8 + width from the lifted input) -> 2.
        self.fc1 = nn.Linear(3*self.width, 4*self.width)
        self.fc2 = nn.Linear(4*self.width, 2)

    @staticmethod
    def _cat_skip(current, skip):
        """Trilinearly resize `skip` to `current`'s grid and concat on channels."""
        resized = F.interpolate(
            skip,
            size=(current.shape[2], current.shape[3], current.shape[4]),
            mode='trilinear',
            align_corners=True)
        return torch.cat([current, resized], dim=1)

    def forward(self, x, time_grid = 128):
        """Run the operator.

        Args:
            x: tensor of shape (batch, X, Y, T, C); three grid coordinates
                are appended so C + 3 must equal `in_width`.
            time_grid: length of the last axis requested from the final
                block (before cropping `padding`).

        Returns:
            Tensor with the channel axis last and 2 output channels.
        """
        coords = self.get_grid(x.shape, x.device)
        lifted = F.gelu(self.fc_n1(torch.cat((x, coords), dim=-1)))
        lifted = F.gelu(self.fc0(lifted))

        # Channels-first layout for the operator blocks.
        lifted = lifted.permute(0, 4, 1, 2, 3)
        D1, D2 = lifted.shape[-3], lifted.shape[-2]

        enc0 = self.conv0(lifted)
        enc1 = self.conv1(enc0)
        enc2 = self.conv2(enc1)

        deep = self.conv5(self.conv4(self.conv3(enc2)))

        dec6 = self._cat_skip(self.conv6(deep), enc1)
        dec7 = self._cat_skip(self.conv7(dec6), enc0)

        dec8 = self.conv8(dec7, D1, D2, time_grid + self.padding)
        dec8 = self._cat_skip(dec8, lifted)

        if self.padding != 0:
            if self.pad_both:
                dec8 = dec8[..., self.padding//2:-self.padding//2]
            else:
                dec8 = dec8[..., :-self.padding]

        # Back to channels-last for the pointwise projection layers.
        dec8 = dec8.permute(0, 2, 3, 4, 1)
        return self.fc2(F.gelu(self.fc1(dec8)))

    def get_grid(self, shape, device):
        """Return normalized (x, y, z) coordinates in [0, 1] for `shape`.

        The result has shape (batch, size_x, size_y, size_z, 3) and is
        moved to `device`.
        """
        batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3]

        def axis(length, layout, tiling):
            # One coordinate axis, reshaped to `layout` and tiled to full size.
            line = torch.tensor(np.linspace(0, 1, length), dtype=torch.float)
            return line.reshape(*layout).repeat(tiling)

        gridx = axis(size_x, (1, size_x, 1, 1, 1), [batchsize, 1, size_y, size_z, 1])
        gridy = axis(size_y, (1, 1, size_y, 1, 1), [batchsize, size_x, 1, size_z, 1])
        gridz = axis(size_z, (1, 1, 1, size_z, 1), [batchsize, size_x, size_y, 1, 1])
        return torch.cat((gridx, gridy, gridz), dim=-1).to(device)

def count_params(model):
    """Return the number of real-valued scalars in `model`'s parameters.

    Complex parameters count twice (real and imaginary parts). Zero-dim
    parameters count as one element each.
    """
    # numel() handles 0-dim tensors, for which reducing an empty shape
    # with operator.mul raised TypeError.
    return sum(p.numel() * (2 if p.is_complex() else 1)
               for p in model.parameters())



In [2]:
################################################################
# configs
################################################################

# Experiment configuration followed by loading of the velocity models and
# source description from disk.
sub = 4  # NOTE(review): not referenced in the visible code
npad = 0
S = 64 - 2 * npad  # spatial size the models and waveforms are resampled to
T_in = 1  # NOTE(review): not referenced in the visible code
T = 128 - 2 * npad  # number of time samples after resampling
T_max = 400  # length of the zero-padded source time function buffer
nx, ny = 64, 64  # spatial size used to reshape the data read from disk

ntrain = 8
ntest = 2
width = 16

batch_size = 8

# Input channels filled below: 0 = Vp, 1 = Vs, 2 = source term.
in_channels = 3

epochs = 100

learning_rate = 1e-3

runtime = np.zeros(2, )  # NOTE(review): not referenced in the visible code
t1 = default_timer()  # start of the preprocessing timer

################################################################
# load data
################################################################
datapath = './'
Vsset = torch.tensor(np.load(datapath + 'model/Vs.npy')).float()
Vpset = torch.tensor(np.load(datapath + 'model/Vp.npy')).float()
srcxy = torch.tensor(np.load(datapath + 'model/srcxy.npy'), dtype=torch.int)
# Copy the first row of the stored source time function into a buffer of
# length T_max; the remainder stays zero.
srct = np.zeros(T_max)
tmp = np.load(datapath + 'model/srct.npy')[0]
srct[:tmp.size] = tmp
srct = torch.tensor(srct, dtype=torch.float)
# Resample both velocity models to (S, S) as single-channel images.
Vpset = F.interpolate(Vpset.view(-1, 1, nx, ny),
                      size=(S, S), antialias=True, mode='bilinear')

Vsset = F.interpolate(Vsset.view(-1, 1, nx, ny),
                      size=(S, S), antialias=True, mode='bilinear')


# Pre-allocated inputs (a) and targets (u); targets carry 2 channels.
train_a = torch.zeros((ntrain, S, S, T, in_channels))
train_u = torch.zeros((ntrain, S, S, T, 2))

test_a = torch.zeros((ntest, S, S, T, in_channels))
test_u = torch.zeros((ntest, S, S, T, 2))

# Index grid on the original (nx, ny) mesh, used to place the source.
x = torch.arange(nx)
y = torch.arange(ny)
xx, yy = torch.meshgrid(x, y, indexing="ij")
src_width = 2  # width parameter of the Gaussian source footprint

def _load_waveform(index):
    """Load waveform number `index` and resample it to shape (S, S, T, 2)."""
    # presumably stored as (nx, ny, nt, 2) on disk — TODO confirm
    raw = torch.tensor(
        np.load(datapath + 'waveform/No' + str(index) + '.npy')).float()
    volume = raw.permute(3, 0, 1, 2).view(1, 2, nx, ny, -1)
    volume = F.interpolate(volume, size=(S, S, T), mode='trilinear')
    return volume.view(2, S, S, -1).permute(1, 2, 3, 0)


def _source_term(index, time_func):
    """Return the Gaussian source footprint of sample `index` times `time_func`."""
    footprint = torch.exp(-(xx - srcxy[index, 0]) ** 2 / src_width ** 2) * \
        torch.exp(-(yy - srcxy[index, 1]) ** 2 / src_width ** 2)
    footprint = F.interpolate(
        footprint.view(1, 1, nx, ny),
        size=(S, S), antialias=True, mode='bilinear')
    return footprint.view(S, S, 1) * time_func


def _fill_split(inputs, targets, offset, count, time_func):
    """Fill `inputs`/`targets` in place with samples offset..offset+count-1."""
    for i in range(count):
        index = offset + i
        u_wf = _load_waveform(index)
        # Velocity models are constant in time and broadcast along T.
        inputs[i, :, :, :, 0] = Vpset[index, :, :].view(S, S, 1)
        inputs[i, :, :, :, 1] = Vsset[index, :, :].view(S, S, 1)

        t_start = 0
        t_stop = t_start + T
        targets[i] = u_wf[:, :, t_start:t_stop, :]
        inputs[i, :, :, :, 2] = _source_term(index, time_func)


# Magnitude of the source time function resampled to T samples; it does not
# depend on the sample, so it is computed once for both splits.
_source_time = torch.abs(
    F.interpolate(srct.view(1, 1, -1), size=(T), mode='linear'))

offset_train = 0
print("Building training set")
_fill_split(train_a, train_u, offset_train, ntrain, _source_time)

# Test samples follow the training samples on disk.
offset_test = ntrain
print("Building test set")
_fill_split(test_a, test_u, offset_test, ntest, _source_time)

# Report dataset shapes, build normalizers and loaders, then create the
# model and optimizer.
print(train_a.shape)
print(train_u.shape)

# Normalization statistics come from the training split only.
a_normalizer = InputNormalizer(train_a)
y_normalizer = InputNormalizer(train_u)
a_normalizer.cuda()
y_normalizer.cuda()

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        train_a,
        train_u),
    batch_size=batch_size,
    shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        test_a,
        test_u),
    batch_size=batch_size,
    shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2 - t1)


# NOTE: this rebinds the module-level x and y created for the mesh grid.
out, x, y = [], [], []

# The model appends three grid coordinates to its input, hence in_channels + 3.
model = Uno3D(in_channels + 3, width, pad=0).to(device)

#model = UNO_3D(in_channels + 3, width, S, T, pad=npad).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5)




print(count_params(model))


Building training set
Building test set
torch.Size([8, 64, 64, 128, 3])
torch.Size([8, 64, 64, 128, 2])
preprocessing finished, time used: 4.160167830064893
694319914


In [3]:
def train():
    """Train the module-level `model` for `epochs` epochs.

    Uses the module-level loaders, normalizers, optimizer and `device`.
    The objective is 0.9 * relative L1 + 0.1 * relative L2, each summed
    over the batch. One line of timing and per-sample losses is printed
    per epoch.

    Returns:
        (losstrain, losstest): per-epoch combined losses as numpy arrays.
        The L2 histories are recorded locally but not returned.
    """
    L2 = LpLoss(p=2, size_average=False)
    L1 = LpLoss(p=1, size_average=False)

    def batch_losses(x, y):
        # Encode one batch, run the model, return (combined, L2, L1) losses.
        x, y = x.to(device), y.to(device)
        y = y_normalizer.encode(y)
        x = a_normalizer.encode(x)
        out = model(x)
        n = x.shape[0]
        l2 = L2(out.view(n, -1), y.view(n, -1))
        l1 = L1(out.view(n, -1), y.view(n, -1))
        return 0.9 * l1 + 0.1 * l2, l2, l1

    losstrain = np.zeros(epochs)
    losstest = np.zeros(epochs)
    l2train = np.zeros(epochs)
    l2test = np.zeros(epochs)

    for ep in range(epochs):
        t1 = default_timer()

        # Optimisation pass over the training loader.
        model.train()
        train_loss = 0
        train_L2 = 0.0
        train_L1 = 0.0
        for x, y in train_loader:
            optimizer.zero_grad()
            loss, L2_loss, L1_loss = batch_losses(x, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_L2 += L2_loss.item()
            train_L1 += L1_loss.item()

        # Evaluation pass without gradient tracking.
        model.eval()
        test_loss = 0.0
        test_L1 = 0.0
        test_L2 = 0.0
        with torch.no_grad():
            for x, y in test_loader:
                loss, L2_loss, L1_loss = batch_losses(x, y)
                test_loss += loss.item()
                test_L2 += L2_loss.item()
                test_L1 += L1_loss.item()

        # Convert batch sums into per-sample averages.
        train_loss /= ntrain
        train_L1 /= ntrain
        train_L2 /= ntrain
        test_loss /= ntest
        test_L1 /= ntest
        test_L2 /= ntest

        t2 = default_timer()
        print(
            ep,
            t2 - t1,
            train_loss,
            test_loss,
            train_L2,
            test_L2,
            train_L1,
            test_L1)

        losstrain[ep] = train_loss
        losstest[ep] = test_loss
        l2train[ep] = train_L2
        l2test[ep] = test_L2
    return losstrain, losstest

# Run training and keep the per-epoch combined train/test loss histories.
losstrain, losstest = train()



0 1.314715189859271 1.0706849098205566 1.038224458694458 1.0367131233215332 1.0158281326293945 1.0744596719741821 1.0407129526138306
1 1.2898991331458092 1.045721411705017 1.0225578546524048 1.0208604335784912 1.007822036743164 1.0484837293624878 1.0241951942443848
2 1.2622177228331566 1.0262469053268433 1.010532021522522 1.0094645023345947 1.0030450820922852 1.0281116962432861 1.0113639831542969
3 1.284719342365861 1.0117501020431519 1.0023083686828613 1.0026497840881348 1.0013346672058105 1.0127612352371216 1.0024166107177734
4 1.2741169054061174 1.0020146369934082 0.9983502626419067 1.0001165866851807 1.001772403717041 1.0022255182266235 0.9979700446128845
5 1.2686746697872877 0.9969125986099243 0.9947887659072876 1.0005877017974854 1.002645492553711 0.9965042471885681 0.9939157962799072
6 1.2686995174735785 0.991974413394928 0.9906934499740601 1.0015556812286377 1.0029761791229248 0.9909098148345947 0.9893287420272827
7 1.2669328823685646 0.986285924911499 0.986458420753479 1.00163

61 1.2781321071088314 0.8941566944122314 1.0158065557479858 0.9577260613441467 1.0063345432281494 0.8870934247970581 1.0168590545654297
62 1.2945908978581429 0.8891575336456299 1.0190646648406982 0.9542859196662903 1.0065827369689941 0.8819210529327393 1.020451545715332
63 1.2589241284877062 0.8828964829444885 1.023069977760315 0.9496189951896667 1.0071232318878174 0.8754829168319702 1.0248419046401978
64 1.2613478247076273 0.8767493367195129 1.0268070697784424 0.9443910717964172 1.0075404644012451 0.8692336082458496 1.0289478302001953
65 1.2607248164713383 0.870589554309845 1.0320067405700684 0.9406582117080688 1.0087101459503174 0.8628041744232178 1.034595251083374
66 1.2522711344063282 0.8634965419769287 1.0352818965911865 0.9335472583770752 1.0089163780212402 0.8557131886482239 1.038211464881897
67 1.260466292500496 0.8561655879020691 1.0419206619262695 0.928678572177887 1.0103424787521362 0.8481086492538452 1.0454293489456177
68 1.2570740878582 0.8482930660247803 1.045260310173034