In [1]:
import os
import numpy as np
import netCDF4 as nc
import xarray as xr
import matplotlib.pyplot as plt
import dask
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from torch import optim

from dataset import *
from model import *
from metrics import *

import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
import time
from carbontracker.tracker import CarbonTracker  # Monitoring carbon footprint
from carbontracker import parser

## Data Lading

In [2]:
def get_dataset(args):
    dataset = args['dataset']
    img_path = dataset['img_path']
    transform = transforms.Compose([transforms.ToTensor()])
    # Load input and reference data from .npy files :
    train_data = Interpolated_Img_Dataset(img_path, 'normalized_input_physical_data_1998_2015.npy', 'ln_Chl_ref_norm_1998_2015.npy', 
                                        transform=transform, normalize=False) 

    return train_data

def get_loader(train_data, args):

    loader = args['dataloader']
    # Define train, validation and test datasets
    # train:2003-2010, val:1998-2001, test:2012-2015
    train_idx, val_idx, test_idx = [i for i in range(60,156)], [i for i in range(0,48)], [i for i in range(168,216)] 

    train_sampler = SubsetRandomSampler(train_idx) 
    val_sampler = SubsetRandomSampler(val_idx)
    test_sampler = SubsetRandomSampler(test_idx)

    # convert to data loaders :
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=loader['train_bs'],
        sampler=train_sampler, num_workers=loader['num_workers'])
    val_loader = torch.utils.data.DataLoader(train_data, batch_size=loader['val_bs'], 
        sampler=val_sampler, num_workers=loader['num_workers'])
    test_loader = torch.utils.data.DataLoader(train_data, batch_size=loader['test_bs'], 
        sampler=test_sampler, num_workers=loader['num_workers'])

    return train_loader, val_loader, test_loader

In [3]:
import yaml

config_file = './config/Chl_CNN.yaml'
with open(config_file, 'r') as f:
    args = yaml.safe_load(f)

train_data = get_dataset(args)
train_loader, val_loader, test_loader = get_loader(train_data, args)

print('len(train_loader): ', len(train_loader))
print('len(validate_loader)', len(val_loader))
print('len(test_loader)', len(test_loader))

for data, target in train_loader:
    print('input data and target:')
    print(data.shape, target.shape)
    break

len(train_loader):  48
len(validate_loader) 24
len(test_loader) 24
input data and target:
torch.Size([2, 9, 178, 358]) torch.Size([2, 1, 178, 358])


## Get model

In [4]:
device = torch.device(args['device'] if torch.cuda.is_available() else "cpu")
n_epochs = args['epochs']

opt = args['optimizer']['name']
lr = args['optimizer']['lr']

In [5]:
# model

model_M1 = CNN_M1()
model_M2 = CNN_M2()
model_M3 = CNN_M3()
model_M4 = CNN_M4()
model_M5 = CNN_M5()
model_M6 = CNN_M6()
model_M7 = CNN_M7()
model_M8 = CNN_M8()
model_W = CNN_W()

# 将其参数和输入数据转换为双精度浮点数
model_M1.to(device)  
model_M2.to(device)  
model_M3.to(device)  
model_M4.to(device)   
model_M5.to(device)    
model_M6.to(device)
model_M7.to(device)   
model_M8.to(device) 
model_W.to(device) 

model_1 = model_M1.double()     
model_2 = model_M2.double()
model_3 = model_M3.double()
model_4 = model_M4.double()
model_5 = model_M5.double()   
model_6 = model_M6.double()
model_7 = model_M7.double()
model_8 = model_M8.double()
model_W = model_W.double()

# specify loss function 
criterion = torch.nn.MSELoss()

# specify optimizer
optimizer_1 = getattr(optim, opt)(model_1.parameters(), lr=lr)
optimizer_2 = getattr(optim, opt)(model_2.parameters(), lr=lr)
optimizer_3 = getattr(optim, opt)(model_3.parameters(), lr=lr)
optimizer_4 = getattr(optim, opt)(model_4.parameters(), lr=lr)
optimizer_5 = getattr(optim, opt)(model_5.parameters(), lr=lr)
optimizer_6 = getattr(optim, opt)(model_6.parameters(), lr=lr)
optimizer_7 = getattr(optim, opt)(model_7.parameters(), lr=lr)
optimizer_8 = getattr(optim, opt)(model_8.parameters(), lr=lr)
optimizer_W = getattr(optim, opt)(model_W.parameters(), lr=lr)

scheduler1 = torch.optim.lr_scheduler.MultiStepLR(optimizer_1, milestones=[400], gamma=0.1)
scheduler2 = torch.optim.lr_scheduler.MultiStepLR(optimizer_2, milestones=[400], gamma=0.1)
scheduler3 = torch.optim.lr_scheduler.MultiStepLR(optimizer_3, milestones=[400], gamma=0.1)
scheduler4 = torch.optim.lr_scheduler.MultiStepLR(optimizer_4, milestones=[400], gamma=0.1)
scheduler5 = torch.optim.lr_scheduler.MultiStepLR(optimizer_5, milestones=[400], gamma=0.1)
scheduler6 = torch.optim.lr_scheduler.MultiStepLR(optimizer_6, milestones=[400], gamma=0.1)
scheduler7 = torch.optim.lr_scheduler.MultiStepLR(optimizer_7, milestones=[400], gamma=0.1)
scheduler8 = torch.optim.lr_scheduler.MultiStepLR(optimizer_8, milestones=[400], gamma=0.1)
schedulerW = torch.optim.lr_scheduler.MultiStepLR(optimizer_W, milestones=[400], gamma=0.1)

print('Definition of 8 sub-models Mi with the architecture :') 
print(model_1)
print('and one attention module W that outputs 8 weighted maps :')
print(model_W)


Definition of 8 sub-models Mi with the architecture :
CNN_M1(
  (conv1): Conv2d(9, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(128, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (dropout): Dropout(p=0.35, inplace=False)
)
and one attention module W that outputs 8 weighted maps :
CNN_W(
  (conv1): Conv2d(9, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (dropout): Dropout(p=0.35, inplace=False)
)


In [6]:
## Specify a folder name in which the trained odels will be saved, as well as their root name :
folder_name = './checkpoint/'
model_name = 'MMCNN'

folder_name2 = folder_name + model_name + '/'
model_save_1 = folder_name + model_name + '_1.pt'
model_save_2 = folder_name + model_name + '_2.pt'
model_save_3 = folder_name + model_name + '_3.pt'
model_save_4 = folder_name + model_name + '_4.pt'
model_save_5 = folder_name + model_name + '_5.pt'
model_save_6 = folder_name + model_name + '_6.pt'
model_save_7 = folder_name + model_name + '_7.pt'
model_save_8 = folder_name + model_name + '_8.pt'
model_save_W = folder_name + model_name + '_W.pt'

## Train

In [7]:
#######################
# Monitoring metrics during training loop :
loss_mask = 0
loss_fig = 10
loss_fig_epoque_affichage = 10  # loss绘图间隔的epoch

# Import continental mask to compute the loss only on the ocean (not on the land)
img_path = args['dataset']['img_path']
mask = torch.load(os.path.join(img_path, 'inter_mask.pt'))
mask = mask.bool()
mask = mask.reshape([1,1,178,358])
mask.to(device)

batch_size = args['dataloader']['train_bs']


# Check for time computation
tps_ini = time.process_time() 


torch.set_default_dtype(torch.float64)
# device = torch.device("cuda:1")s

# Create output folder to save the data
if not os.path.exists(folder_name2):
    os.makedirs(folder_name2)

In [9]:
#############################################
######### Weighted neural network  ##########  #### 8 modes
#############################################



#######################
# Training loop :
#######################

# Initialisation of the loss value
train_losses_save, valid_losses_save = [], []

valid_loss_min = np.Inf # track change in validation loss

# from carbontracker import parser
for epoch in range(1, n_epochs+1):
    
    # keep track of training and validation loss
    train_loss = 0.0
    valid_loss = 0.0
    
    compteur = 0  # 计数器
        
    ###################
    # train the model #
    ###################
        
    model_1.train()
    model_2.train()
    model_3.train()
    model_4.train()
    model_5.train()
    model_6.train()
    model_7.train()
    model_8.train()
    model_W.train()
        
    for data, target in train_loader:

        data,target=data.to(device,dtype=torch.float64),target.to(device,dtype=torch.float64)
        # clear the gradients of all optimized variables
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()
        optimizer_3.zero_grad()
        optimizer_4.zero_grad()
        optimizer_5.zero_grad()
        optimizer_6.zero_grad()
        optimizer_7.zero_grad()
        optimizer_8.zero_grad()
        optimizer_W.zero_grad()
            
        ################################################################
        # Calculation of the final output as a combination of the 8 sub-models
        ################################################################
            
        output_W = model_W(data.double())  # [bs, 1, 178, 358]
        output_1 = model_1(data.double()) 
        output_2 = model_2(data.double()) 
        output_3 = model_3(data.double()) 
        output_4 = model_4(data.double()) 
        output_5 = model_5(data.double()) 
        output_6 = model_6(data.double()) 
        output_7 = model_7(data.double()) 
        output_8 = model_8(data.double()) 
            
        # Weighted average
        
        W1 = output_W[:,0,:,:]
        W1 = torch.reshape(W1, (batch_size, 1,178, 358))
        W2 = output_W[:,1,:,:]
        W2 = torch.reshape(W2, (batch_size, 1,178, 358))
        W3 = output_W[:,2,:,:]
        W3 = torch.reshape(W3, (batch_size, 1,178, 358))
        W4 = output_W[:,3,:,:]
        W4 = torch.reshape(W4, (batch_size, 1,178, 358))
        W5 = output_W[:,4,:,:]
        W5 = torch.reshape(W5, (batch_size, 1,178, 358))
        W6 = output_W[:,5,:,:]
        W6 = torch.reshape(W6, (batch_size, 1,178, 358))
        W7 = output_W[:,6,:,:]
        W7 = torch.reshape(W7, (batch_size, 1,178, 358))
        W8 = output_W[:,7,:,:]
        W8 = torch.reshape(W8, (batch_size, 1,178, 358))
        
        output_1 = torch.mul(output_1, W1)  # [bs, 1, 178, 358]
        output_2 = torch.mul(output_2, W2)
        output_3 = torch.mul(output_3, W3)
        output_4 = torch.mul(output_4, W4)
        output_5 = torch.mul(output_5, W5)
        output_6 = torch.mul(output_6, W6)
        output_7 = torch.mul(output_7, W7)
        output_8 = torch.mul(output_8, W8)
            
        # Concatenation
        concat = torch.cat((output_1, output_2, output_3, output_4, output_5, output_6, output_7, output_8), 1) 
            
        # Sum on the last dimension
        output = torch.sum(concat, 1)  # [bs, 178, 358]
        output = torch.reshape(output, (batch_size, 1,178, 358)) # Reshape to apply the continental mask
            
        ######################################################
        # Loss computation
        ######################################################
        
        compteur = compteur + 1
        
        output2 = output
        target2 = target
        
        where_are_NaNs = np.isnan(target.cpu())  #  NaN 值的位置为 True，其他位置为 False
        
        #  Training only on 50N - 50S(排除极区数据，仅在赤道附近)
        
        where_are_NaNs[:,0,0:39,:] = 1
        where_are_NaNs[:,0,139:178,:] = 1
        
        where_are_NaNs = where_are_NaNs.bool()
        where_are_NaNs = ~where_are_NaNs
        where_are_NaNs = where_are_NaNs.to(device)
        
        output2 = torch.masked_select(output2, where_are_NaNs)
        target2 = torch.masked_select(target2, where_are_NaNs)
                 
        # calculate the batch loss
        loss = criterion(output2.double(), target2.double())

        ######################################################
        # backward pass: compute gradient of the loss with respect to model parameters
        ######################################################
        
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer_1.step()
        optimizer_2.step()
        optimizer_3.step()
        optimizer_4.step()
        optimizer_5.step()
        optimizer_6.step()
        optimizer_7.step()
        optimizer_8.step()
        optimizer_W.step()
        
        # updatse training los
        train_loss += loss.item()
           
    scheduler1.step()
    scheduler2.step()
    scheduler3.step()
    scheduler4.step()
    scheduler5.step()
    scheduler6.step()
    scheduler7.step()
    scheduler8.step()
    schedulerW.step()
    
    model_1.eval()
    model_2.eval()
    model_3.eval()
    model_4.eval()
    model_5.eval()
    model_6.eval()
    model_7.eval()
    model_8.eval()
    model_W.eval()
    
    for data, target in val_loader:

        data,target=data.to(device,dtype=torch.float),target.to(device,dtype=torch.float)
            
        output_W = model_W(data.double()) 
        output_1 = model_1(data.double()) 
        output_2 = model_2(data.double()) 
        output_3 = model_3(data.double()) 
        output_4 = model_4(data.double())
        output_5 = model_5(data.double()) 
        output_6 = model_6(data.double()) 
        output_7 = model_7(data.double())
        output_8 = model_8(data.double())
            
        
        W1 = output_W[:,0,:,:]
        W1 = torch.reshape(W1, (batch_size, 1,178, 358))
        W2 = output_W[:,1,:,:]
        W2 = torch.reshape(W2, (batch_size, 1,178, 358))
        W3 = output_W[:,2,:,:]
        W3 = torch.reshape(W3, (batch_size, 1,178, 358))
        W4 = output_W[:,3,:,:]
        W4 = torch.reshape(W4, (batch_size, 1,178, 358))
        W5 = output_W[:,4,:,:]
        W5 = torch.reshape(W5, (batch_size, 1,178, 358))
        W6 = output_W[:,5,:,:]
        W6 = torch.reshape(W6, (batch_size, 1,178, 358))
        W7 = output_W[:,6,:,:]
        W7 = torch.reshape(W7, (batch_size, 1,178, 358))
        W8 = output_W[:,7,:,:]
        W8 = torch.reshape(W8, (batch_size, 1,178, 358))
        
        output_1 = torch.mul(output_1, W1)
        output_2 = torch.mul(output_2, W2)
        output_3 = torch.mul(output_3, W3)
        output_4 = torch.mul(output_4, W4)
        output_5 = torch.mul(output_5, W5)
        output_6 = torch.mul(output_6, W6)
        output_7 = torch.mul(output_7, W7)
        output_8 = torch.mul(output_8, W8)

        concat = torch.cat((output_1, output_2, output_3, output_4, output_5, output_6, output_7, output_8), 1)   
        output = torch.sum(concat, 1)
        output = torch.reshape(output, (batch_size, 1,178, 358)) 
          
        output2 = output
        target2 = target
        
        where_are_NaNs = np.isnan(target.cpu())
        
        where_are_NaNs[:,0,0:39,:] = 1
        where_are_NaNs[:,0,139:178,:] = 1
        
        where_are_NaNs = where_are_NaNs.bool()
        where_are_NaNs = ~where_are_NaNs
        where_are_NaNs = where_are_NaNs.to(device)
        
        output2 = torch.masked_select(output2, where_are_NaNs)
        target2 = torch.masked_select(target2, where_are_NaNs)
        

        # calculate the batch loss
        loss = criterion(output2.double(), target2.double())
        # update average validation loss 
        valid_loss += loss.item()
        

    # calculate average losses
    train_loss = train_loss/len(train_loader.sampler)  # 每个epoch
    valid_loss = valid_loss/len(val_loader.sampler)
    
    train_losses_save.append(train_loss)
    valid_losses_save.append(valid_loss)
    
    # Save the loss in files
    np.save(folder_name2 + 'train_losses_save.npy', train_losses_save)
    np.save(folder_name2 + 'valid_losses_save.npy', valid_losses_save)
    
    # print loss/training curves to monitor overfitting / convergence
    if loss_fig ==1 and epoch % loss_fig_epoque_affichage ==0 :
        fig1=plt.figure(figsize=(3,3))
        plt.plot(valid_losses_save, label='Validation loss',color='orange')
        plt.plot(train_losses_save, label='Training loss',color='blue', linestyle='dashed')
        plt.legend(frameon=False)
        plt.show()
        
    # print training/validation statistics 
    tps_t = time.process_time() 
    epoch_time = round((tps_t - tps_ini)/60)
    print('Epoch: {}/{} \tTraining Loss: {:.6f} \tValidation Loss 1: {:.6f} \ttime: {}min\t'.format(epoch, n_epochs, train_loss, valid_loss, epoch_time))      
    
    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        torch.save(model_1.state_dict(), model_save_1 )   
        torch.save(model_2.state_dict(), model_save_2 ) 
        torch.save(model_3.state_dict(), model_save_3 ) 
        torch.save(model_4.state_dict(), model_save_4 ) 
        torch.save(model_5.state_dict(), model_save_5 ) 
        torch.save(model_6.state_dict(), model_save_6 ) 
        torch.save(model_7.state_dict(), model_save_7 ) 
        torch.save(model_8.state_dict(), model_save_8 ) 
        torch.save(model_W.state_dict(), model_save_W ) 
        
        valid_loss_min = valid_loss

    
tps_final = time.process_time() 
tps_final_min = round((tps_final - tps_ini)/60,2)
print('Time computation : {} min'.format(tps_final_min))

########################################################################
### Check overfitting / save loss
########################################################################

plt.plot(valid_losses_save, label='Validation loss',color='orange')
plt.plot(train_losses_save, label='Training loss',color='blue', linestyle='dashed')

plt.legend(frameon=False)
plt.savefig('./result/' + folder_name2 + 'Loss', dpi= 100,bbox_inches = "tight")
plt.show()


Epoch: 1/50 	Training Loss: 0.157509 	Validation Loss 1: 0.129430 	time: 8min	
Epoch: 2/50 	Training Loss: 0.137176 	Validation Loss 1: 0.117465 	time: 12min	
Epoch: 3/50 	Training Loss: 0.122529 	Validation Loss 1: 0.112572 	time: 17min	
Epoch: 4/50 	Training Loss: 0.113231 	Validation Loss 1: 0.110051 	time: 21min	
Epoch: 5/50 	Training Loss: 0.108636 	Validation Loss 1: 0.101121 	time: 26min	
Epoch: 6/50 	Training Loss: 0.106177 	Validation Loss 1: 0.086066 	time: 31min	
Epoch: 7/50 	Training Loss: 0.097881 	Validation Loss 1: 0.086089 	time: 36min	
Epoch: 8/50 	Training Loss: 0.098116 	Validation Loss 1: 0.088728 	time: 40min	
Epoch: 9/50 	Training Loss: 0.090387 	Validation Loss 1: 0.083048 	time: 44min	
Epoch: 10/50 	Training Loss: 0.089115 	Validation Loss 1: 0.081574 	time: 49min	
Epoch: 11/50 	Training Loss: 0.085814 	Validation Loss 1: 0.074712 	time: 53min	
Epoch: 12/50 	Training Loss: 0.084457 	Validation Loss 1: 0.072972 	time: 58min	
Epoch: 13/50 	Training Loss: 0.082827 

## Test