In [None]:
%load_ext autoreload
%autoreload 2
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline

import yaml
import glob
import torch
import numpy as np
from torch import nn
from torchvision.models import ResNet
from torch.utils.data import DataLoader,Dataset

import logging
from importlib import reload  # Not needed in Python 2
# Re-execute the logging module so the basicConfig call below takes effect even
# if the kernel's root logger already has handlers (basicConfig is a no-op then).
reload(logging)

# Reproducibility: cuDNN is disabled and deterministic mode is requested, but
# those flags only make runs repeatable when the random generators are seeded
# too, so seed torch and numpy here.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.enabled = False
torch.backends.cudnn.deterministic = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training progress is written to logging.txt rather than the notebook output.
logging.basicConfig(
    level=logging.INFO,      # minimum level that gets written
    filename='logging.txt',
    filemode='a',            # 'a' appends across runs (the default); 'w' would overwrite the file each run
    format='%(asctime)s : %(message)s',
)



from dataset.ASdataset import AS_Data
from dataset.ASdataset_obs_train_input import AS_Data_obs


In [None]:
# yaml.load without an explicit Loader is unsafe and raises TypeError on
# PyYAML >= 6 (Loader became required); safe_load is enough for plain config data.
with open('config/cfg.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

# Flatten the 'step1' section together with 'share_cfg'. On duplicate keys the
# 'share_cfg' value wins because it is unpacked last.
# NOTE(review): confirm that precedence is intended.
cfg = {**cfg['step1'], **cfg['share_cfg']}
T = cfg['T']                   # passed as `window` to AS_Data and as `T` to the model
pollution = cfg['pollution']
batch_size = cfg['batch_size']

print('train data is loading ')
Data = AS_Data(cfg['data_path'], left=cfg['train']['left'], right=cfg['train']['right'],
               window=T, pollution=pollution)
trainloader = DataLoader(Data, batch_size=batch_size, shuffle=True)
print(len(Data))

print('test data is loading ')
test_Data = AS_Data(cfg['data_path'], left=cfg['test']['left'], right=cfg['test']['right'],
                    window=T, pollution=pollution)
testloader = DataLoader(test_Data, batch_size=batch_size, shuffle=True)
print(len(test_Data))

In [None]:
from model.res_model_LSTM import res8
from model.unet_model_LSTM import UNet

# Model, loss and optimiser. The UNet's input channels are the meteorological
# plus emission feature counts, and `pre_dim` is set to the number of pollutants.
# To compare against the ResNet variant, build `res8(...)` here instead; to
# warm-start, call test_model.load_state_dict(torch.load(<checkpoint>)) before training.
test_model = UNet(
    cfg['meteorological_dim'] + cfg['emission_dim'],
    cfg['grid_dim'],
    T=T,
    bilinear=False,
    pre_dim=len(pollution),
)
name = cfg['name']  # used as the checkpoint path when saving/loading the best weights

test_model = test_model.to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(test_model.parameters(), lr=1e-3)


In [None]:
def score(model, loader, criterion=nn.L1Loss(), percent=False, device=None):
    """Evaluate `model` on every batch of `loader`, reporting one metric per output channel.

    Each batch must unpack as ``(input, grid, yt_1, label)``. The model is called
    as ``model(input, grid, yt_1)``, and its output is compared with ``label`` one
    channel (index 1) at a time.

    Args:
        model: module to evaluate; it is switched to eval mode and left there.
        loader: iterable of ``(input, grid, yt_1, label)`` tensor batches.
        criterion: loss applied per channel when ``percent`` is False.
        percent: if True, report mean relative error
            ``|pred - label| / (label + eps)`` for each eps in
            ``(0.1, 1, 4, 8, 12, 16)`` instead of ``criterion``.
        device: device the batch tensors are moved to. Defaults to the device of
            the model's first parameter (CPU for a parameter-free model).

    Returns:
        1-D numpy array. Without ``percent`` it has one entry per channel. With
        ``percent`` it has 6 entries per channel, ordered channel-major. Each
        entry is averaged with every batch weighted by its size, so a shorter
        final batch does not skew the dataset-level value.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if device is None:
        # Send inputs wherever the model's weights live instead of relying on a
        # global defined in another cell.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    per_batch = []
    batch_sizes = []
    with torch.no_grad():
        for inputs, grid, yt_1, label in loader:
            inputs, grid, yt_1, label = inputs.to(device), grid.to(device), yt_1.to(device), label.to(device)
            y_pred = model(inputs, grid, yt_1)

            cur_loss = []
            for j in range(label.shape[1]):
                if percent:
                    for eps in (0.1, 1, 4, 8, 12, 16):
                        rel_err = torch.abs(y_pred[:, j] - label[:, j]) / (label[:, j] + eps)
                        cur_loss.append(torch.mean(rel_err).item())
                else:
                    cur_loss.append(criterion(y_pred[:, j], label[:, j]).item())

            per_batch.append(cur_loss)
            batch_sizes.append(label.shape[0])

    if not per_batch:
        raise ValueError("score() received an empty loader")
    # Size-weighted mean of per-batch means == mean over all samples.
    return np.average(np.array(per_batch), axis=0, weights=batch_sizes)

In [None]:
# Train with early stopping. After every epoch the model is scored on the test
# loader, and whenever the summed per-channel L1 improves the weights are saved
# to the path held in `name`. Progress goes through `logging`.
n_epochs = 21
early_stop = 15      # stop after this many consecutive epochs without improvement
best_score = np.inf  # +inf so the first evaluated epoch always writes a checkpoint
early_cnt = 0
for epoch in range(n_epochs):
    logging.info('-----------{}-----------'.format(epoch))
    ls = []

    test_model.train()
    for inputs, grid, yt_1, label in trainloader:
        inputs, grid, yt_1, label = inputs.to(device), grid.to(device), yt_1.to(device), label.to(device)
        y_pred = test_model(inputs, grid, yt_1)

        assert y_pred.shape == label.shape

        optimizer.zero_grad()
        loss = criterion(y_pred, label)
        loss.backward()
        optimizer.step()
        ls.append(loss.item())
        if len(ls) % 400 == 0:
            logging.info('epoch {} cur loss {}'.format(epoch, np.mean(ls)))

    logging.info('epoch {} cur loss {}'.format(epoch, np.mean(ls)))
    test_score_L1 = score(test_model, testloader, criterion=nn.L1Loss())
    logging.info('-------------cur test loss L1:  {}'.format(','.join([str(s) for s in test_score_L1])))

    epoch_score = np.sum(test_score_L1)
    if epoch_score < best_score:
        early_cnt = 0
        best_score = epoch_score
        # Save CPU copies of the weights without moving the live model off
        # `device` and back; the state_dict's own mapping object is reused.
        checkpoint = test_model.state_dict()
        for key, tensor in checkpoint.items():
            checkpoint[key] = tensor.cpu()
        torch.save(checkpoint, name)
    else:
        early_cnt += 1
        if early_cnt >= early_stop:
            break

In [None]:
# Restore the best weights written by the training loop, then place the model
# on the compute device for the visual check below.
test_model.load_state_dict(
    torch.load(name)
)
test_model.to(device)  # nn.Module.to moves parameters in place and returns the same module

In [None]:
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline

def image_show(im, title='', base=80):
    """Draw a 2-D map on the current axes, flipped vertically, with a colorbar.

    The colour range is stretched to include ``base``. Panels drawn with the same
    ``base`` therefore share a common reference value.

    Args:
        im: 2-D array-like to display; it is not modified.
        title: axes title.
        base: value the colour range is forced to include (default 80).
    """
    image = np.asarray(im)
    # Pin the colour range with vmin/vmax instead of painting marker pixels into
    # the image, so no real data is hidden.
    vmin = min(float(np.nanmin(image)), base)
    vmax = max(float(np.nanmax(image)), base)
    plt.title(title)
    # [::-1] reverses the rows, so array row 0 is drawn at the bottom.
    plt.imshow(image[::-1], cmap=plt.cm.hot_r, vmin=vmin, vmax=vmax)
    plt.colorbar()
    

In [None]:
# Visual sanity check: for the first few test batches, plot the label, the
# prediction and their absolute difference for one output channel. All three
# panels pass the label's maximum as `base` to image_show.
plt.rcParams['figure.figsize'] = (16, 7.0)

air_idx = 0            # output channel to plot; NOTE(review): originally annotated as pm25 — confirm against cfg['pollution']
n_batches_to_plot = 7  # batches 0..6, the same stopping point as an `idx > 5` check

test_model.eval()
with torch.no_grad():
    for idx, (inputs, grid, yt_1, label) in enumerate(testloader):
        # No squeeze on dim 1 here: the training loop and `score` pass the input
        # unchanged, so squeezing would at best be a no-op and at worst change
        # the shape the model sees.
        y_pred = test_model(inputs.to(device), grid.to(device), yt_1.to(device))

        # The label is only plotted, so it never needs to leave the CPU.
        label = label.numpy()
        y_pred = y_pred.cpu().numpy()
        for b_idx in range(len(label)):
            image1 = label[b_idx, air_idx]
            max_label = np.max(image1)

            plt.subplot(1, 3, 1)
            image_show(image1, 'label', base=max_label)

            plt.subplot(1, 3, 2)
            image2 = y_pred[b_idx, air_idx]
            image_show(image2, 'pred', base=max_label)

            plt.subplot(1, 3, 3)
            image3 = np.abs(image2 - image1)
            image_show(image3, 'diff', base=max_label)
            plt.show()

        if idx + 1 >= n_batches_to_plot:
            break
