# Solving Maxwell's equations using the attention mechanism

We use both the encoder and the decoder of the Transformer, introduced in the paper

Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information Processing Systems, pages 6000-6010.

The encoder is used as an embedding block, which converts the 3D rho and J vectors at a given time, e.g., of shape [batch_size, num_steps, 8, nx, ny, nz], into **an** embedded_source_vector. Here, the self-attention mechanism is used. **Note that this procedure could, in principle, be replaced by 3D convolutional layers.**

Similarly, the corresponding electromagnetic fields of shape [batch_size, num_steps, 10, nx, ny, nz] are converted into **an** embedded_EM_vector.

## 1. Import the packages used

In [1]:
import os, math
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Use the first CUDA device visible under CUDA_VISIBLE_DEVICES when available, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


## 2. The JefiAtten blocks

### 2.1 Self-attention and Cross-attention

In [5]:
class MultiLayerSelfAttention(nn.Module):
    """One spatial self-attention layer, applied independently per head.

    ``forward`` takes ``x`` indexed ``[b, i, j, k, h, d]``. Each grid point and
    head gets the score ``<x, x>_d / sqrt(embeded_dim)``. The scores are
    softmax-normalised over all grid points (i, j, k) of a sample, separately
    per head, and used to re-weight ``x`` residually. A residual position-wise
    feed-forward network (width ``4 * embeded_dim``) follows.

    ``query_EM_dim`` and ``mem_source_dim`` are accepted so every attention
    block shares one constructor signature, but they are not used here.
    ``num_layers`` is only stored. ``layer_norm_1`` and ``layer_norm_2`` are
    created but not applied in ``forward``.
    """

    def __init__(self, query_EM_dim, mem_source_dim, embeded_dim, num_heads, num_layers, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.embeded_dim = embeded_dim
        self.MHembeding = num_heads * embeded_dim

        ffn_width = embeded_dim * 4
        self.self_atten_linear_1 = nn.Linear(embeded_dim, ffn_width)
        self.self_atten_linear_2 = nn.Linear(ffn_width, embeded_dim)
        self.relu = nn.ReLU()

        # Scaling factor of the dot-product score.
        self.sqrt_d = torch.tensor(embeded_dim, dtype=torch.float32).sqrt()
        self.num_layers = num_layers

        self.layer_norm_1 = nn.LayerNorm(embeded_dim)
        self.layer_norm_2 = nn.LayerNorm(embeded_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, n_batch):
        """Return ``x`` after spatial self-attention and the feed-forward step."""
        energy = torch.einsum('bijkhd,bijkhd->bijkh', x, x) / self.sqrt_d
        # Flatten (i, j, k) into one axis so the softmax runs over grid points per head.
        per_sample = energy.view(n_batch, -1, self.num_heads)
        attn = torch.softmax(per_sample, dim=1).view_as(energy)

        attended = torch.einsum('bijkh,bijkhd->bijkhd', attn, x)
        x = x + self.dropout(attended)
        ffn_out = self.self_atten_linear_2(self.relu(self.self_atten_linear_1(x)))
        return x + self.dropout(ffn_out)
     
    
class MultiLayerAttention(nn.Module):
    """Spatial cross-attention layer, applied independently per head.

    ``forward(k, q, v, n_batch)`` takes three tensors indexed
    ``[b, i, j, k, h, d]``. For each grid point and head, ``<q, k>_d /
    sqrt(embeded_dim)`` is a score. The scores are softmax-normalised over all
    grid points of a sample (separately per head) and re-weight ``v``
    residually. A residual feed-forward network followed by ``layer_norm_2``
    finishes the layer.

    ``query_EM_dim`` and ``mem_source_dim`` are accepted for a uniform
    constructor signature but are not used. ``self.linear`` and
    ``layer_norm_1`` are created but not used in ``forward``.
    """

    def __init__(self, query_EM_dim, mem_source_dim, embeded_dim, num_heads, num_layers, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.embeded_dim = embeded_dim
        self.MHembeding = num_heads * embeded_dim

        ffn_width = embeded_dim * 4
        self.atten_linear_1 = nn.Linear(embeded_dim, ffn_width)
        self.atten_linear_2 = nn.Linear(ffn_width, embeded_dim)
        self.relu = nn.ReLU()

        self.linear = nn.Linear(self.MHembeding, embeded_dim)
        self.sqrt_d = torch.tensor(embeded_dim, dtype=torch.float32).sqrt()
        self.num_layers = num_layers

        self.layer_norm_1 = nn.LayerNorm(embeded_dim)
        self.layer_norm_2 = nn.LayerNorm(embeded_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, k, q, v, n_batch):
        """Return ``v`` attended with ``q``/``k`` weights, then feed-forward + layer norm."""
        logits = torch.einsum('bijkhd,bijkhd->bijkh', q, k) / self.sqrt_d
        # Flatten (i, j, k) so the softmax normalises over grid points per head.
        probs = torch.softmax(logits.reshape(n_batch, -1, self.num_heads), dim=1).view_as(logits)

        mixed = torch.einsum('bijkh,bijkhd->bijkhd', probs, v)
        v = v + self.dropout(mixed)
        hidden = self.atten_linear_2(self.relu(self.atten_linear_1(v)))
        return self.layer_norm_2(v + self.dropout(hidden))

### 2.2 Time-attention 

In [None]:
class TimeSelfAttention(nn.Module):
    """Collapse the trailing time axis of a source or EM history by attention.

    ``forward`` takes ``x`` indexed ``[b, i, j, k, h, t]`` (batch, grid
    i/j/k, channel, time) with ``t == embeded_dim``.
    - For every grid point and channel, the score ``<x, x>_t / sqrt(t)`` is
      softmax-normalised over all grid points of a sample, separately for
      each channel.
    - The weights re-scale ``x`` residually.
    - A residual feed-forward network acts along ``t``.
    - The ``h * t`` features of each grid point are projected to ``h``
      outputs, giving ``[b, i, j, k, h]``.
    ``linear_rho`` is used when ``h == rho_J_dim``, otherwise ``linear_em``.

    Args:
        rho_J_dim: number of source channels.
        EM_dim: number of EM channels.
        embeded_dim: length of the time axis ``t``.
        dropout: dropout probability.
    """

    def __init__(self, rho_J_dim, EM_dim, embeded_dim, dropout):
        super(TimeSelfAttention, self).__init__()

        self.embeded_dim = embeded_dim
        self.rho_J_dim = rho_J_dim
        self.EM_dim = EM_dim

        self.self_atten_linear_1 = nn.Linear(embeded_dim, embeded_dim * 4)
        self.self_atten_linear_2 = nn.Linear(embeded_dim * 4, embeded_dim)
        self.relu = nn.ReLU()

        self.sqrt_d = torch.sqrt(torch.tensor(self.embeded_dim, dtype=torch.float32))

        # Not applied in forward; kept so the state_dict keys stay the same.
        self.layer_norm_1 = nn.LayerNorm(self.embeded_dim)
        self.layer_norm_2 = nn.LayerNorm(self.embeded_dim)

        self.linear_rho = nn.Linear(self.embeded_dim * self.rho_J_dim, self.rho_J_dim)
        self.linear_em = nn.Linear(self.embeded_dim * self.EM_dim, self.EM_dim)
        self.dropout = nn.Dropout(dropout)

    def _attention_weights(self, x, n_batch):
        """Return weights ``[b, i, j, k, h]`` that sum to 1 over (i, j, k) for each b and h."""
        scores = torch.einsum('bijkht,bijkht->bijkh', x, x) / self.sqrt_d
        # Keep the channel axis last and flatten only the grid axes. Reshaping by
        # the time length instead would mix channels with grid points, and would
        # fail unless nx*ny*nz*h happened to be divisible by it.
        n_channels = scores.shape[-1]
        flat = scores.reshape(n_batch, -1, n_channels)
        return torch.softmax(flat, dim=1).view_as(scores)

    def forward(self, x, n_batch):
        """Return the time-collapsed features ``[b, i, j, k, h]`` of ``x``."""
        weights = self._attention_weights(x, n_batch)

        x = x + self.dropout(torch.einsum('bijkh,bijkht->bijkht', weights, x))
        x = x + self.dropout(self.self_atten_linear_2(self.relu(self.self_atten_linear_1(x))))

        # Merge (h, t) and project back to h channels.
        flat = x.reshape(*x.shape[:-2], -1)
        if x.shape[4] == self.rho_J_dim:
            return self.linear_rho(flat)
        return self.linear_em(flat)

### 2.3 Local-attention 

In [6]:
class _LocalSelfAttentionBase(nn.Module):
    """Shared implementation of the local (per-grid-point) attention blocks.

    ``forward(k, q, v, n_batch)`` takes tensors indexed ``[b, i, j, k, c]``
    with ``c == dim``. ``q`` may have batch size 1 and is then broadcast.
    - The element-wise score ``q * k / sqrt(dim)`` is softmax-normalised over
      all grid points of a sample, separately per channel.
    - ``v`` is re-weighted residually.
    - ``v`` then passes through a residual feed-forward network of width
      ``4 * dim``.
    """

    def __init__(self, dim, dropout):
        super().__init__()

        self.embeded_dim = dim

        self.self_atten_linear_1 = nn.Linear(self.embeded_dim, self.embeded_dim * 4)
        self.self_atten_linear_2 = nn.Linear(self.embeded_dim * 4, self.embeded_dim)
        self.relu = nn.ReLU()

        self.sqrt_d = torch.sqrt(torch.tensor(self.embeded_dim, dtype=torch.float32))

        # Not applied in forward; kept so the state_dict keys stay the same.
        self.layer_norm_1 = nn.LayerNorm(self.embeded_dim)
        self.layer_norm_2 = nn.LayerNorm(self.embeded_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, k, q, v, n_batch):
        """Return ``v`` after local attention and the feed-forward step."""
        scores = q * k / self.sqrt_d
        # Flatten (i, j, k) so the softmax runs over grid points per channel.
        weights = torch.softmax(scores.view(n_batch, -1, self.embeded_dim), dim=1).view_as(scores)

        v = v + self.dropout(weights * v)
        v = v + self.dropout(self.self_atten_linear_2(self.relu(self.self_atten_linear_1(v))))
        return v


class LocalrhoSelfAttention(_LocalSelfAttentionBase):
    """Local attention over the ``rho_J_dim`` source (rho, J) channels."""

    def __init__(self, rho_J_dim, dropout):
        super().__init__(rho_J_dim, dropout)


class LocalemSelfAttention(_LocalSelfAttentionBase):
    """Local attention over the ``EM_dim`` electromagnetic-field channels."""

    def __init__(self, EM_dim, dropout):
        super().__init__(EM_dim, dropout)

In [7]:
# Mesh definitions. The "_o" / "_s" suffixes presumably denote the output and
# source grids. The source grid covers [-3, 3] on each axis (left boundary -3,
# 20 cells of width 6/20).
x_grid_size_o, y_grid_size_o, z_grid_size_o = 20, 20, 20
x_grid_size_s, y_grid_size_s, z_grid_size_s = 20, 20, 20
dx_s, dy_s, dz_s = 6 / x_grid_size_s, 6 / y_grid_size_s, 6 / z_grid_size_s
x_left_boundary_s, y_left_boundary_s, z_left_boundary_s = -3, -3, -3
dt = 0.005

# Cell-centre coordinates of every grid point: ix varies only along axis 0,
# iy only along axis 1 and iz only along axis 2.
ix = torch.zeros((x_grid_size_s, y_grid_size_s, z_grid_size_s))
iy = torch.zeros_like(ix)
iz = torch.zeros_like(ix)
for i in range(x_grid_size_s):
    ix[i, :, :] = x_left_boundary_s + dx_s * (i + 0.5)
for j in range(y_grid_size_s):
    iy[:, j, :] = y_left_boundary_s + dy_s * (j + 0.5)
for k in range(z_grid_size_s):
    iz[:, :, k] = z_left_boundary_s + dz_s * (k + 0.5)

# [1, 3, nx, ny, nz] -> channels-last [1, nx, ny, nz, 3].
local_tensor = torch.stack((ix, iy, iz)).unsqueeze(0)
local_tensor_use = local_tensor.permute(0, 2, 3, 4, 1).contiguous().to(
    torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

### 2.4 Subject of the procedure

In [8]:
class MultiLayerAttentionEM(nn.Module):
    """Predict EM-field channels on the grid from source and EM histories.

    Pipeline in ``forward``:
      1. ``self.time`` (a linear map over the time axis) followed by
         ``TimeSelfAttention`` collapses the time axis of both inputs.
      2. Local attention mixes in the grid coordinates ``local_tensor_use``,
         projected by ``local_rho`` / ``local_em``.
      3. The EM branch is projected to ``q`` and ``v``, and the source branch
         to ``k``. Each has ``num_heads`` heads of size ``embeded_dim``.
      4. ``q`` and ``k`` are refined by ``num_layers`` spatial self-attention
         layers. ``v`` then alternates self-attention and cross-attention.
      5. The heads are projected to ``output_EM_dim`` channels.

    Args:
        query_EM_dim: number of EM channels per grid point.
        mem_source_dim: number of source channels per grid point.
        time_dim: length of the time axis of both inputs.
        embeded_dim: per-head feature size.
        output_EM_dim: number of output channels.
        num_heads: number of attention heads.
        num_layers: number of stacked attention layers.
        dropout: dropout probability used by all sub-blocks.
    """

    def __init__(self, query_EM_dim, mem_source_dim, time_dim, embeded_dim, output_EM_dim, num_heads, num_layers, dropout):
        super(MultiLayerAttentionEM, self).__init__()

        self.num_heads = num_heads
        self.embeded_dim = embeded_dim
        self.output_EM_dim = output_EM_dim
        self.MHembeding = self.num_heads * self.embeded_dim

        # Per-point projections to num_heads * embeded_dim features.
        self.query = nn.Linear(query_EM_dim, self.MHembeding)
        self.key = nn.Linear(mem_source_dim, self.MHembeding)
        self.value = nn.Linear(query_EM_dim, self.MHembeding)

        # Linear map over the time axis, shared by the source and EM inputs.
        self.time = nn.Linear(time_dim, time_dim)

        # Project the 3 grid coordinates to the source / EM channel counts.
        self.local_rho = nn.Linear(3, mem_source_dim)
        self.local_em = nn.Linear(3, query_EM_dim)

        self.num_layers = num_layers

        self.dropout = nn.Dropout(dropout)

        # Sub-blocks are not pinned to a device here, so the model can be built
        # without a GPU. Move the whole model with ``.to(device)``.
        def _stack(layer_cls):
            return nn.Sequential(*[
                layer_cls(query_EM_dim, mem_source_dim, embeded_dim, num_heads, num_layers, dropout)
                for _ in range(self.num_layers)
            ])

        self.self_atten_k = _stack(MultiLayerSelfAttention)
        self.self_atten_q = _stack(MultiLayerSelfAttention)
        self.self_atten_v = _stack(MultiLayerSelfAttention)
        self.atten = _stack(MultiLayerAttention)

        self.linear = nn.Linear(embeded_dim * num_heads, output_EM_dim)

        self.self_atten_time = TimeSelfAttention(mem_source_dim, query_EM_dim, time_dim, dropout)

        self.atten_local_rho = LocalrhoSelfAttention(mem_source_dim, dropout)
        self.atten_local_em = LocalemSelfAttention(query_EM_dim, dropout)

    def forward(self, source_inputs, EM_inputs, local_tensor_use):
        """Return predicted EM channels ``[b, i, j, k, output_EM_dim]``.

        Args:
            source_inputs: source history ``[b, i, j, k, mem_source_dim, time_dim]``.
            EM_inputs: EM history ``[b, i, j, k, query_EM_dim, time_dim]``.
            local_tensor_use: grid coordinates ``[1, i, j, k, 3]``.
        """
        n_batch = source_inputs.shape[0]

        # Collapse the time axis: [b, i, j, k, c, t] -> [b, i, j, k, c].
        source_time_attention = self.self_atten_time(self.time(source_inputs), n_batch)
        EM_time_attention = self.self_atten_time(self.time(EM_inputs), n_batch)

        source_local_attention = self.atten_local_rho(
            source_time_attention, self.local_rho(local_tensor_use), source_time_attention, n_batch)
        em_local_attention = self.atten_local_em(
            EM_time_attention, self.local_em(local_tensor_use), EM_time_attention, n_batch)

        # Split the projected features into heads: [..., num_heads, embeded_dim].
        q = self.query(em_local_attention).view(*em_local_attention.shape[:-1], self.num_heads, self.embeded_dim)
        k = self.key(source_local_attention).view(*source_local_attention.shape[:-1], self.num_heads, self.embeded_dim)
        v = self.value(em_local_attention).view(*em_local_attention.shape[:-1], self.num_heads, self.embeded_dim)

        # Spatial self-attention on q and k.
        # NOTE(review): q goes through ``self_atten_k`` and k through
        # ``self_atten_q``. Both stacks have the same architecture, so this only
        # decides which trained weights are used; it is left as is so existing
        # checkpoints keep their meaning.
        for i_layer in range(self.num_layers):
            q = self.self_atten_k[i_layer](q, n_batch)
            k = self.self_atten_q[i_layer](k, n_batch)

        # Alternate self-attention on v with cross-attention from (q, k).
        for i_layer in range(self.num_layers):
            v = self.self_atten_v[i_layer](v, n_batch)
            v = self.atten[i_layer](k, q, v, n_batch)

        # Merge heads and project to the output channels.
        return self.linear(v.reshape(*v.shape[:-2], -1))

# encoder for rho and J
# transformer_encoder_EM = MultiLayerAttentionEM(6, 4, 10, 16, 6, 16, 1, 0.1).to(torch.device('cuda:0'))
# transformer_encoder_EM = DataParallel(transformer_encoder_EM, device_ids=list(range(1)))

# src_EM = torch.randn(1, nx, ny, nz, 6, 10).to(torch.device('cuda:0'))
# tgt = torch.randn(1, nx, ny, nz, 4, 10).to(torch.device('cuda:0'))
# output_EM = transformer_encoder_EM(tgt, src_EM,local_tensor_use)
# print(tgt.shape, src_EM.shape, output_EM.shape)

## 3. Optimizer

In [11]:
from numpy import random
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
# Draw 10 random sample indices in [0, 500) (unseeded, so different each run).
index_number = random.randint(500, size=10)
index_number

array([444, 207, 371,  93, 134, 498, 276,  31,  32, 169])

In [13]:
class Trainer:
    """Train an EM-field predictor with a SmoothL1 loss.

    Args:
        model_G: model called as ``model_G(source, em_history, local_tensor_use)``.
        local_tensor_use: grid-coordinate tensor passed to every model call.
        batch_size: DataLoader batch size for phase 0 (phase 1 always uses 1).
        num_gpus: when > 1, the model is wrapped in ``DataParallel`` over
            GPUs ``0 .. num_gpus - 1``.
        num_workers: number of DataLoader worker processes.
    """

    def __init__(self, model_G, local_tensor_use, batch_size, num_gpus, num_workers):
        self.model_G = model_G
        self.batch_size = batch_size
        self.num_gpus = num_gpus
        self.num_workers = num_workers
        self.local_tensor_use = local_tensor_use
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.Hubercriterion = nn.SmoothL1Loss()

    def train(self, num_epochs, num_epochs1, learning_rate_G):
        """Run both training phases.

        Both phases share these rules:
        - Every epoch visits data files 190-193.
        - The first 10 batches of each file are skipped.
        - Adam takes one step per batch.

        Phase 0 (``num_epochs`` epochs) uses ``MyDataset`` and always feeds
        the recorded EM history.

        Phase 1 (``num_epochs1`` epochs) uses ``MyDataset1`` with batch size 1.
        For samples with index > 960, the previous EM history is shifted one
        time step and the model's previous prediction is written into the
        newest slot. The history is reset to the recorded one every 10 samples.

        Returns:
            ``(losses, model)``. ``losses`` has length
            ``num_epochs + num_epochs1`` and holds the summed loss of each epoch
            (phase-1 epochs start at index ``num_epochs``). ``model`` is the
            trained, possibly ``DataParallel``-wrapped, model.
        """
        # Parallelise training over GPUs by setting the number of GPUs; specific
        # GPUs can also be chosen through their device ids.
        self.model_G = self.model_G.to(self.device)
        if self.num_gpus > 1:
            self.model_G = DataParallel(self.model_G, device_ids=list(range(self.num_gpus)))

        optimizer_G = optim.Adam(self.model_G.parameters(), lr=learning_rate_G)
        running_loss_epoch_text = torch.zeros((num_epochs + num_epochs1))

        for phase, phase_epochs in enumerate((num_epochs, num_epochs1)):
            dataset_cls = MyDataset if phase == 0 else MyDataset1
            loader_batch_size = self.batch_size if phase == 0 else 1
            for epoch in range(phase_epochs):
                generator_loss_s = 0.0
                number_s = 0
                for j in range(4):
                    dataset = dataset_cls(190 + j)
                    self.train_loader = DataLoader(dataset=dataset, batch_size=loader_batch_size,
                                                   shuffle=False, num_workers=self.num_workers)
                    for i, (input1, input2, label) in enumerate(self.train_loader):
                        if i < 10:
                            continue

                        # Use the real batch size: the last batch of a file may be smaller.
                        n = input1.shape[0]
                        input1 = input1.to(self.device).reshape(n, 20, 20, 20, 4, 10)
                        input2 = input2.to(self.device).reshape(n, 20, 20, 20, 6, 10)
                        label = label.to(self.device)

                        if phase == 0:
                            generated_fake = self.model_G(input1, input2, self.local_tensor_use)
                        else:
                            if i > 960:
                                # Shift the EM history one step along time and put the
                                # previous prediction (no gradient) into the newest slot.
                                em_history = torch.roll(em_history, shifts=-1, dims=5)
                                em_history[..., -1] = generated_fake.detach()
                                generated_fake = self.model_G(input1, em_history, self.local_tensor_use)
                            else:
                                generated_fake = self.model_G(input1, input2, self.local_tensor_use)
                                em_history = input2
                            # Re-anchor the rollout on the recorded history every 10 samples.
                            if i % 10 == 0 and i != 0:
                                em_history = input2

                        generator_loss = self.Hubercriterion(generated_fake, label)
                        optimizer_G.zero_grad()
                        generator_loss.backward()
                        optimizer_G.step()
                        if i % 500 == 0 and number_s > 1:
                            print('[%d, %d, %d] generator_loss: %.5f; ' % (phase, epoch + 1, j, generator_loss_s / number_s))

                        generator_loss_s += generator_loss.item()
                        number_s += 1
                running_loss_epoch_text[phase * num_epochs + epoch] = generator_loss_s

        return running_loss_epoch_text, self.model_G

In [14]:
class MyDataset(Dataset):
    """(rho/J input, EM input, label) triples read from one set of ``.npy`` files.

    Loads ``10_data_rho_J-<num>.npy``, ``10_data_EM-<num>.npy`` and
    ``10_label-<num>.npy`` from the ``Transfomer-rho-J``, ``Transfomer-EM``
    and ``Transfomer-label`` sub-directories of ``data_root``. Items are
    returned as float32 tensors, with shapes taken from the stored arrays.

    Args:
        num: index embedded in the three file names.
        data_root: directory that contains the three sub-directories.
    """

    def __init__(self, num, data_root='/data/zhangjunjie/Data'):
        # NOTE(review): allow_pickle=True lets a crafted .npy file run arbitrary
        # code on load; only use it with trusted data.
        self.inputs1 = np.load(os.path.join(data_root, 'Transfomer-rho-J', '10_data_rho_J-' + str(num) + '.npy'), allow_pickle=True)
        self.inputs2 = np.load(os.path.join(data_root, 'Transfomer-EM', '10_data_EM-' + str(num) + '.npy'), allow_pickle=True)
        self.labels = np.load(os.path.join(data_root, 'Transfomer-label', '10_label-' + str(num) + '.npy'), allow_pickle=True)

    @staticmethod
    def _to_tensor(item):
        # torch.tensor always copies the data; torch.Tensor(<int scalar>) would
        # instead allocate an uninitialised tensor of that size.
        return torch.tensor(np.asarray(item), dtype=torch.float32)

    def __getitem__(self, index):
        input1 = self._to_tensor(self.inputs1[index])
        input2 = self._to_tensor(self.inputs2[index])
        label = self._to_tensor(self.labels[index])
        return input1, input2, label

    def __len__(self):
        return len(self.labels)

In [14]:
class MyDataset1(MyDataset):
    """Dataset used for phase-1 training.

    It reads the same files and returns the same items as ``MyDataset``; only
    the name differs, so phase-1 data can be swapped independently.
    """

In [None]:
# Hyper-parameters of the run.
batch_size = 1
query_EM_dim, mem_source_dim, time_dim = 6, 4, 10
embeded_dim, output_EM_dim = 16, 6
num_heads, num_layers, dropout = 4, 1, 0.2

model_G = MultiLayerAttentionEM(
    query_EM_dim, mem_source_dim, time_dim, embeded_dim,
    output_EM_dim, num_heads, num_layers, dropout,
).to(device)
# To resume instead: model_G = torch.load('20231113_200_epochs.pth')

import time

st = time.time()
trainer = Trainer(model_G, local_tensor_use, batch_size=batch_size, num_gpus=1, num_workers=8)
running_loss_epoch_text, model_S = trainer.train(
    num_epochs=50, num_epochs1=0, learning_rate_G=0.0001)
# Wall-clock training time in hours.
print((time.time() - st) / 3600)

In [21]:
# Mesh definitions. The "_o" / "_s" suffixes presumably denote the output and
# source grids. The source grid covers [-3, 3] on each axis (left boundary -3,
# 20 cells of width 6/20).
x_grid_size_o, y_grid_size_o, z_grid_size_o = 20, 20, 20
x_grid_size_s, y_grid_size_s, z_grid_size_s = 20, 20, 20
dx_s, dy_s, dz_s = 6 / x_grid_size_s, 6 / y_grid_size_s, 6 / z_grid_size_s
x_left_boundary_s, y_left_boundary_s, z_left_boundary_s = -3, -3, -3
dt = 0.005

# Cell-centre coordinates of every grid point: ix varies only along axis 0,
# iy only along axis 1 and iz only along axis 2.
ix = torch.zeros((x_grid_size_s, y_grid_size_s, z_grid_size_s))
iy = torch.zeros_like(ix)
iz = torch.zeros_like(ix)
for i in range(x_grid_size_s):
    ix[i, :, :] = x_left_boundary_s + dx_s * (i + 0.5)
for j in range(y_grid_size_s):
    iy[:, j, :] = y_left_boundary_s + dy_s * (j + 0.5)
for k in range(z_grid_size_s):
    iz[:, :, k] = z_left_boundary_s + dz_s * (k + 0.5)

# [1, 3, nx, ny, nz] -> channels-last [1, nx, ny, nz, 3].
local_tensor = torch.stack((ix, iy, iz)).unsqueeze(0)
local_tensor_use = local_tensor.permute(0, 2, 3, 4, 1).contiguous().to(
    torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

In [None]:
import matplotlib.pyplot as plt

# Evaluate model_S on data file 194, starting at sample 800.
# - Up to sample 960, the recorded EM history is fed to the model.
# - After sample 960, each prediction is fed back into the next step's history.
# For samples number .. number+5, plot the x-index-5 slice of component 2:
# label (row 0), prediction (row 1) and element-wise SmoothL1 loss (row 2).
xi, yi = np.mgrid[1:21:1, 1:21:1]
fig1, axes1 = plt.subplots(ncols=6, nrows=3, figsize=(15, 7))
lossfunction = nn.SmoothL1Loss()

dataset = MyDataset(194)
n_batch = 1
# Samples must arrive in order because the rollout feeds each prediction
# into the next step.
val_loader = DataLoader(dataset=dataset, batch_size=n_batch, shuffle=False, num_workers=1)
generator_loss = 0.0  # not used in this cell
generator_loss_all = 0.0  # not used in this cell
label_plot = torch.zeros((6, 20, 20))
out_plot = torch.zeros((6, 20, 20))
loss = torch.zeros((6, 20, 20))
# Per-sample loss; at least 1920 entries and never shorter than the dataset.
loss_all = torch.zeros(max(1920, len(dataset)))
input_middle = torch.zeros((n_batch, 20, 20, 20, 9, 10))  # not used in this cell
number = 1000
number1 = 0
with torch.no_grad():
    for x, (input1, input2, label) in enumerate(val_loader):
        if x < 800:
            continue
        input_text = input1.to(device).reshape(n_batch, 20, 20, 20, 4, 10)
        input1_text = input2.to(device).reshape(n_batch, 20, 20, 20, 6, 10)
        label_text = label.to(device)

        if x > 960:
            # Shift the EM history one step along time and put the previous
            # prediction into the newest slot.
            input1_last = torch.roll(input1_last, shifts=-1, dims=5)
            input1_last[..., -1] = generated_fake
            generated_fake = model_S(input_text, input1_last, local_tensor_use)
        else:
            generated_fake = model_S(input_text, input1_text, local_tensor_use)
            input1_last = input1_text
        # Re-anchor the rollout on the recorded history every 10 samples.
        if x % 10 == 0 and x != 0:
            input1_last = input1_text
        loss_all[x] = lossfunction(generated_fake, label_text).item()

        if number <= x <= number + 5:
            m = x % number
            label_plot[m] = label_text[0, 5, :, :, 2].cpu()
            out_plot[m] = generated_fake[0, 5, :, :, 2].cpu()
            loss[m] = F.smooth_l1_loss(label_plot[m], out_plot[m], reduction='none')
            axes1[0, m].pcolormesh(xi, yi, label_plot[m].numpy(), cmap='viridis')
            axes1[1, m].pcolormesh(xi, yi, out_plot[m].numpy(), cmap='viridis')
            axes1[2, m].pcolormesh(xi, yi, loss[m].numpy(), cmap='RdBu')
