In [1]:
import os
import torch
from copy import deepcopy
import numpy as np
import xarray as xr
import pandas as pd
import torch.nn as nn
import random
from torch.utils.data import Dataset, DataLoader
import zipfile
import shutil
device = 'cuda' if torch.cuda.is_available() else 'cpu'   
device

'cuda'

In [2]:
def set_seed(seed = 427):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)

In [3]:
def load_data():
    # CMIP data    
    train = xr.open_dataset('tcdata/enso_round1_train_20210201/CMIP_train.nc')
    label = xr.open_dataset('tcdata/enso_round1_train_20210201/CMIP_label.nc')    
   
    train_sst = train['sst'][:, :12].values  # (4645, 12, 24, 72)
    train_t300 = train['t300'][:, :12].values
    train_ua = train['ua'][:, :12].values
    train_va = train['va'][:, :12].values
    train_label = label['nino'][:, 12:36].values

    train_ua = np.nan_to_num(train_ua)
    train_va = np.nan_to_num(train_va)
    train_t300 = np.nan_to_num(train_t300)
    train_sst = np.nan_to_num(train_sst)

    # SODA data    
    train2 = xr.open_dataset('tcdata/enso_round1_train_20210201/SODA_train.nc')
    label2 = xr.open_dataset('tcdata/enso_round1_train_20210201/SODA_label.nc')
    
    train_sst2 = train2['sst'][:, :12].values  # (100, 12, 24, 72)
    train_t3002 = train2['t300'][:, :12].values
    train_ua2 = train2['ua'][:, :12].values
    train_va2 = train2['va'][:, :12].values
    train_label2 = label2['nino'][:, 12:36].values

    print('Train samples: {}, Valid samples: {}'.format(len(train_label), len(train_label2)))

    dict_train = {
        'sst':train_sst,
        't300':train_t300,
        'ua':train_ua,
        'va': train_va,
        'label': train_label}
    dict_valid = {
        'sst':train_sst2,
        't300':train_t3002,
        'ua':train_ua2,
        'va': train_va2,
        'label': train_label2}
    train_dataset = EarthDataSet(dict_train)
    valid_dataset = EarthDataSet(dict_valid)
    return train_dataset, valid_dataset

In [4]:
class EarthDataSet(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['sst'])

    def __getitem__(self, idx):   
        return (self.data['sst'][idx], self.data['t300'][idx], self.data['ua'][idx], self.data['va'][idx]), self.data['label'][idx]

In [63]:
class simpleSpatailTimeNN(nn.Module):
    def __init__(self, kernals=[3]):
        super(simpleSpatailTimeNN, self).__init__()
        self.conv1 = nn.ModuleList([nn.Conv2d(in_channels=12, out_channels=24, kernel_size=i) for i in kernals]) 
        self.conv2 = nn.ModuleList([nn.Conv2d(in_channels=12, out_channels=24, kernel_size=i) for i in kernals])
        self.conv3 = nn.ModuleList([nn.Conv2d(in_channels=12, out_channels=24, kernel_size=i) for i in kernals])
        self.conv4 = nn.ModuleList([nn.Conv2d(in_channels=12, out_channels=24, kernel_size=i) for i in kernals])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.batch_norm = nn.BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True)
        self.linear = nn.Linear(96, 24)

    def forward(self, sst, t300, ua, va):
        for conv1 in self.conv1:
            sst = conv1(sst)  # batch * 24 * (24 - 2) * (72 -2)
        for conv2 in self.conv2:
            t300 = conv2(t300)
        for conv3 in self.conv3:
            ua = conv3(ua)
        for conv4 in self.conv4:
            va = conv4(va)
        
        x = torch.cat([sst, t300, ua, va], dim=1) # batch * 96 * (24 - 2) * (72 -2)
        x = self.batch_norm(x)
        x = self.avgpool(x).squeeze(dim=-1).squeeze(dim=-1)
        x = self.linear(x)
        return x

In [6]:
def coreff(x, y):
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    c1 = sum((x - x_mean) * (y - y_mean))
    c2 = sum((x - x_mean)**2) * sum((y - y_mean)**2)
    return c1/np.sqrt(c2)

def rmse(preds, y):
    r = np.sqrt(sum((preds - y)**2) / preds.shape[0])
    return r

def eval_score(preds, label):
    acskill_socre = 0
    rmse_score = 0
    a = [1.5]*4 + [2]*7 + [3]*7 + [4]*6
    for i in range(24):
        r = rmse(preds[:, i], label[:, i], ) # T时刻 (100,)
        cor = coreff(preds[:, i], label[:, i], )
    
        rmse_score += r
        acskill_socre += a[i] * np.log(i+1) * cor
    print("acskill_socre:{}, rmse_score:{}".format(2/3*acskill_socre, rmse_score))
    return 2/3 * acskill_socre - rmse_score

In [7]:
def train(num_epochs):
    best_score = 0
    for epoch in range(num_epochs):
        model.train()
        for step, ((sst, t300, ua, va), label) in enumerate(train_loader):                
            sst = sst.to(device).float()
            t300 = t300.to(device).float()
            ua = ua.to(device).float()
            va = va.to(device).float()
            optimizer.zero_grad()
            label = label.to(device).float()
            preds = model(sst, t300, ua, va)
            loss = loss_fn(preds, label)
            loss.backward()
            optimizer.step()
            if step%20 == 0:
                print('Step: {}, Train Loss: {}'.format(step, loss))

        model.eval()
        y_true, y_pred = [], []
        for step, ((sst, t300, ua, va), label) in enumerate(valid_loader):
            sst = sst.to(device).float()
            t300 = t300.to(device).float()
            ua = ua.to(device).float()
            va = va.to(device).float()
            label = label.to(device).float()
            preds = model(sst, t300, ua, va)

            y_pred.append(preds)
            y_true.append(label)

        y_true = torch.cat(y_true, axis=0)
        y_pred = torch.cat(y_pred, axis=0)
        score = eval_score(y_true.cpu().detach().numpy(), y_pred.cpu().detach().numpy())
        print('Epoch: {}, Valid Score {}'.format(epoch+1,score))

        if score > best_score:
            torch.save(model.state_dict(), './models/basemodel_epoch_{}.pt'.format(epoch))
            torch.save(model.state_dict(), './models/basemodel_best.pt')
            print('Model saved successfully')
            best_score = score
        print()

In [8]:
set_seed()
train_dataset, valid_dataset = load_data()      
train_loader = DataLoader(train_dataset, batch_size=32)
valid_loader = DataLoader(valid_dataset, batch_size=32)

Train samples: 4645, Valid samples: 100


In [64]:
model = simpleSpatailTimeNN()
device = 'cuda' if torch.cuda.is_available() else 'cpu'   
optimizer = torch.optim.Adam(model.parameters(), lr=8e-5)
loss_fn = nn.MSELoss()   

model = model.to(device)
loss_fn = loss_fn.to(device)

In [65]:
model

simpleSpatailTimeNN(
  (conv1): ModuleList(
    (0): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1))
  )
  (conv2): ModuleList(
    (0): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1))
  )
  (conv3): ModuleList(
    (0): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1))
  )
  (conv4): ModuleList(
    (0): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1))
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (batch_norm): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear): Linear(in_features=96, out_features=24, bias=True)
)

In [66]:
train(num_epochs=30)

Step: 0, Train Loss: 1.3815606832504272
Step: 20, Train Loss: 0.7326895594596863
Step: 40, Train Loss: 1.0478894710540771
Step: 60, Train Loss: 0.9542536735534668
Step: 80, Train Loss: 0.4335620105266571
Step: 100, Train Loss: 0.24059800803661346
Step: 120, Train Loss: 0.39798635244369507
Step: 140, Train Loss: 0.5819956064224243
acskill_socre:1.737307030518238, rmse_score:18.623725658157543
Epoch: 1, Valid Score -16.886418627639305

Step: 0, Train Loss: 1.3625991344451904
Step: 20, Train Loss: 0.7136222720146179
Step: 40, Train Loss: 1.0212348699569702
Step: 60, Train Loss: 0.9349280595779419
Step: 80, Train Loss: 0.4249457120895386
Step: 100, Train Loss: 0.2420976161956787
Step: 120, Train Loss: 0.3882640302181244
Step: 140, Train Loss: 0.5799397826194763
acskill_socre:5.826650072549178, rmse_score:18.57177661296857
Epoch: 2, Valid Score -12.745126540419392

Step: 0, Train Loss: 1.346677541732788
Step: 20, Train Loss: 0.6927112340927124
Step: 40, Train Loss: 0.9936569929122925
Step: 

In [94]:
model.load_state_dict(torch.load('models/basemodel_epoch_5.pt'))

<All keys matched successfully>

In [7]:
test_path = './tcdata/enso_round1_test_20210201/'

### load test data
files = os.listdir(test_path)
test_feas_dict = {}
for file in files:
    test_feas_dict[file] = np.load(test_path + file)

In [8]:
### 2. predict
test_predicts_dict = {}
for file_name,val in test_feas_dict.items():
    SST = torch.tensor(val[:,:,:,0]).unsqueeze(0).to(device).float()
    T300 = torch.tensor(val[:,:,:,1]).unsqueeze(0).to(device).float()
    Ua = torch.tensor(val[:,:,:,2]).unsqueeze(0).to(device).float()
    Va = torch.tensor(val[:,:,:,3]).unsqueeze(0).to(device).float()
    test_predicts_dict[file_name] = model(SST, T300, Ua, Va).view(-1).detach().cpu().numpy()
#     test_predicts_dict[file_name] = model.predict(val.reshape([-1,12])[0,:])

In [9]:
### 3. save results
if os.path.exists('./result/'):  
    shutil.rmtree('./result/', ignore_errors=True)  
os.makedirs('./result/')
for file_name, val in test_predicts_dict.items(): 
    np.save('./result/' + file_name, val)

In [44]:
def make_zip(res_dir='./result', output_dir='result.zip'):  
    z = zipfile.ZipFile(output_dir, 'w')  
    for file in os.listdir(res_dir):  
        if '.npy' not in file:
            continue
        z.write(res_dir + os.sep + file)  
    z.close()

In [45]:
make_zip()

./result/test_0144-01-12.npy
