In [35]:
ckpt_path="./ckpt/model_mae_78.38278.pt"  # best checkpoint so far; the file name presumably records its validation MAE

In [36]:
from pathlib import Path
import re
import os
import numpy as np
import pandas as pd
import itertools
import matplotlib.pyplot as plt
import time
import pickle
import zipfile
from datetime import *

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import *
from torch.nn.utils.rnn import *
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset

import torch.optim as optim
from torch.optim.lr_scheduler import *

from datetime import datetime
from tqdm.notebook import tqdm, trange
from IPython.core.debugger import set_trace
from IPython.display import display, HTML, clear_output

from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer

# from sklearn_pandas import DataFrameMapper
import sys
import wandb

import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything

import msgpack
from box import Box
import math
import yaml
import warnings
# Silence every warning for the rest of the session (same filter as filterwarnings('ignore')).
warnings.simplefilter('ignore')

# Global seed for python / numpy / torch via pytorch_lightning.
SEED = 55
seed_everything(SEED)


from train_utils import train, valid, dict2device
# from ds_utils import DsLoader, GisDS, get_train_dl, get_valid_dl
from common_utils import *


# Project directories come from ./config.yml (keys: data_dir, pkl_dir, msg_dir).
# A context manager closes the file handle, which the previous one-liner leaked.
with open('./config.yml') as config_file:
    paths = yaml.safe_load(config_file)

data_dir = Path(paths['data_dir'])
pkl_dir = Path(paths['pkl_dir'])
msg_dir = Path(paths['msg_dir'])


# Link vocabulary. link_counter presumably is a collections.Counter of link -> frequency
# (it exposes most_common()); only links seen more than MIN_LINK_FREQ times get an id.
link_counter = load_pickle(pkl_dir / 'link_freq.pkl')

MIN_LINK_FREQ = 30
links_30 = [link for link, freq in link_counter.most_common() if freq > MIN_LINK_FREQ]

# Ids start at 1: index 0 is shared by padding (pad_sequence pads with 0) and by
# out-of-vocabulary links (clean_linkids maps unknown links to 0).
link2id = {link: idx for idx, link in enumerate(links_30, start=1)}


def clean_linkids(link_ids, link2id):
    """Map raw link ids to embedding indices.

    Every id is normalised with ``str(int(...))`` before the lookup, so numeric
    ids such as ``12.0`` and the string ``"12"`` both become the key ``"12"``.
    Ids that are missing from ``link2id`` map to 0.
    """
    normalised_keys = (str(int(raw_id)) for raw_id in link_ids)
    return [link2id.get(key, 0) for key in normalised_keys]


def batch2tensor(batch, name, log_trans=False, long_tensor=False):
    """Gather field ``name`` from every record in ``batch`` into a 1-D tensor.

    Args:
        batch: sequence of dict-like records.
        name: key to read from each record.
        log_trans: when ``== True``, return the natural log of the values.
        long_tensor: when ``== True``, cast each value with ``int`` and build an
            int64 tensor; otherwise build a float32 tensor.
    """
    as_long = long_tensor == True
    values = [int(record[name]) if as_long else record[name] for record in batch]
    tensor = torch.LongTensor(values) if as_long else torch.FloatTensor(values)
    return torch.log(tensor) if log_trans == True else tensor

# Standardisation statistics consumed by `collate`.
# dist / eta / simple_eta statistics were computed on log-transformed values
# (collate takes the log of these fields before standardising them).
dist_min = 2.4535277755531824
dist_max = 11.879284286444815
dist_mean = 8.325948361544423
dist_std = 0.6799133140855674

eta_min = 2.3978952727983707
eta_max = 9.371353167823885
eta_mean = 6.553886963677842
eta_std = 0.5905307292899195

# NOTE: "simple_eat" is a misspelling of "simple_eta"; the names are kept because
# collate refers to them.
simple_eat_min = 0.6931471805599453
simple_eat_max = 9.320180837655714
simple_eat_mean = 6.453206241137908
simple_eat_std = 0.5758803681400783

# Temperature statistics; collate standardises these without a log transform.
high_temp_mean = 31.84375
high_temp_std = 1.6975971069426339
low_temp_mean = 26.46875
low_temp_std = 0.9348922063532245

# link_time statistics on raw values (no log); collate min-max scales link/cross times.
link_time_min = 0.0
link_time_max = 2949.12
link_time_mean = 6.843469259130468
link_time_std = 8.63917700058627


# Lookup from raw driver id to embedding index. CombineModel sizes driver_emb as
# len(driver2id), so values presumably lie in 0..len(driver2id)-1 -- TODO confirm.
driver2id = load_pickle(pkl_dir/"driver2id_dct.pkl")

Global seed set to 55


In [37]:
def collate(batch, date=None):
    """Collate a list of order records into the model's feature dict and target.

    Args:
        batch: list of dict-like order records, as decoded from the msgpack files.
        date: unused; kept so existing callers keep working.

    Returns:
        ``(features, eta)``. ``features`` maps feature names to tensors, except
        ``order_id``, which stays a plain list. ``eta`` is the standardised log
        travel time.
    """
    def standardised(field, mean, std, log_trans=False):
        # z-score a scalar field, optionally after taking its log
        return (batch2tensor(batch, field, log_trans=log_trans) - mean) / std

    def padded(seqs):
        # right-pad variable-length 1-D sequences with zeros into (batch, max_len)
        return pad_sequence(seqs, batch_first=True)

    # Per-step sequences: the link steps followed by the cross steps. The ratio and
    # status channels are filled with 1 for the cross steps.
    start_seqs, end_seqs, ratio_seqs, status_seqs, time_seqs = [], [], [], [], []
    for rec in batch:
        cross_fill = [1] * len(rec['cross_start'])
        start_seqs.append(torch.LongTensor(clean_linkids(rec['link_id'] + rec['cross_start'], link2id)))
        end_seqs.append(torch.LongTensor(clean_linkids(rec['link_id'] + rec['cross_end'], link2id)))
        ratio_seqs.append(torch.FloatTensor(list(rec['link_ratio']) + cross_fill))
        status_seqs.append(torch.FloatTensor(list(rec['link_current_status']) + cross_fill))
        time_seqs.append(torch.FloatTensor(rec['link_time'] + rec['cross_time']))

    slice_id = batch2tensor(batch, 'slice_id', long_tensor=True)
    eta = standardised('eta', eta_mean, eta_std, log_trans=True)

    features = {
        "order_id": [rec['order_id'] for rec in batch],
        "dist": standardised('dist', dist_mean, dist_std, log_trans=True),
        "simple_eta": standardised('simple_eta', simple_eat_mean, simple_eat_std, log_trans=True),
        "driver_id": torch.LongTensor([driver2id[rec['driver_id']] for rec in batch]),
        "slice_id": slice_id,
        # assumes 5-minute slices (288 per day, matching CombineModel's slice_num)
        "hour": (slice_id * 5) // 60,
        "weekday": batch2tensor(batch, 'weekday', long_tensor=True),
        "weather": batch2tensor(batch, 'weather', long_tensor=True),
        "low_temp": standardised('lowtemp', low_temp_mean, low_temp_std),
        "high_temp": standardised('hightemp', high_temp_mean, high_temp_std),

        "link_cross_start": padded(start_seqs),
        "link_cross_end": padded(end_seqs),
        "link_cross_time": (padded(time_seqs) - link_time_min) / (link_time_max - link_time_min),
        "link_cross_len": torch.FloatTensor([len(rec['link_id']) + len(rec['cross_start']) for rec in batch]),
        "link_cross_current_status": padded(status_seqs) / 5,
        "link_cross_ratio": padded(ratio_seqs),

        # first-round aggregate features
        "link_time_total": torch.FloatTensor([sum(rec['link_time']) for rec in batch]) / 1000,
        "cross_time_total": torch.FloatTensor([sum(rec['cross_time']) for rec in batch]) / 100,
        "link_len": torch.FloatTensor([len(rec['link_time']) for rec in batch]) / 1000,
        "cross_len": torch.FloatTensor([len(rec['cross_time']) for rec in batch]) / 10,
    }
    return features, eta

In [76]:
class CombineModel(nn.Module):
    """ETA regressor: an LSTM over the link/cross sequence plus order-level features.

    ``forward`` takes the feature dict produced by ``collate``. It returns one
    prediction per order with shape ``(batch,)``, in the standardised log-ETA space
    of ``collate``'s target.

    Submodule names and shapes must stay as they are, because checkpoints are
    restored with ``load_state_dict`` and are keyed by them.
    """

    def __init__(self):
        super().__init__()

        # Categorical embeddings: (vocabulary size, embedding dim).
        slice_num, slice_dim = 288, 20
        driver_num, driver_dim = len(driver2id), 20
        weekday_num, weekday_dim = 7, 3
        weather_num, weather_dim = 5, 3

        link_emb_dim = 20  # 20 has worked best so far
        num_numeric = 4  # simple_eta, dist, low_temp, high_temp (see forward)

        # +1 because index 0 is reserved for padding / out-of-vocabulary links.
        self.link_emb = nn.Embedding(len(link2id) + 1, link_emb_dim)
        self.slice_emb = nn.Embedding(slice_num, slice_dim)
        self.driver_emb = nn.Embedding(driver_num, driver_dim)
        self.weekday_emb = nn.Embedding(weekday_num, weekday_dim)
        # NOTE(review): weather_emb is not used in forward(); it is kept so that
        # existing checkpoints (which contain its weights) still load strictly.
        self.weather_emb = nn.Embedding(weather_num, weather_dim)

        # Per step: link embedding + link time + current status + ratio (1 channel each).
        lstm_input_dim = link_emb_dim + 3
        lstm_output_dim = 128
        self.lstm = LSTM(lstm_input_dim, lstm_output_dim, batch_first=True)

        # Derived from the concatenation in forward() instead of hard-coding 175
        # (= 128 + 4 + 20 + 20 + 3, the width the existing checkpoint expects).
        linear_dim = lstm_output_dim + num_numeric + slice_dim + driver_dim + weekday_dim
        self.linear = Sequential(
            Linear(linear_dim, 256),
            LeakyReLU(inplace=True),
            Linear(256, 128),
            LeakyReLU(inplace=True),
            Linear(128, 1),
        )

    def forward(self, x):
        # Each step is embedded by its start and end ids; the two embeddings are averaged.
        x_link_cross = (self.link_emb(x['link_cross_start']) + self.link_emb(x['link_cross_end'])) / 2

        x_lstm = torch.cat([
            x_link_cross,
            x['link_cross_time'].unsqueeze(-1),
            x['link_cross_current_status'].unsqueeze(-1),
            x['link_cross_ratio'].unsqueeze(-1),
        ], dim=-1)

        # Packing makes the LSTM stop at each sequence's true length, so padding is
        # ignored; pack_padded_sequence requires the lengths on the CPU.
        packed = pack_padded_sequence(x_lstm,
                                      x['link_cross_len'].cpu(),
                                      batch_first=True,
                                      enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # h_n is (num_layers, batch, hidden); take the last layer's final state.
        # Unlike reshape(batch, -1), this stays correct if num_layers is increased.
        seq_repr = h_n[-1]

        x_num = torch.stack([x['simple_eta'], x['dist'], x['low_temp'], x['high_temp']], dim=-1)

        x_comb = torch.cat([
            seq_repr,
            x_num,
            self.slice_emb(x['slice_id']),
            self.driver_emb(x['driver_id']),
            self.weekday_emb(x['weekday']),
        ], dim=-1)

        return self.linear(x_comb).reshape(-1)

In [79]:
class GisDS(Dataset):
    """In-memory dataset of order records decoded from a single msgpack file.

    The whole file is read and decoded up front with ``use_list=False``, so
    msgpack arrays come back as tuples.

    Loading is best-effort: if the file cannot be opened/read, or its contents
    cannot be decoded, the dataset is empty. DsLoader, for example, preloads a
    day that may not exist.
    """

    def __init__(self, file_path):
        try:
            with open(file_path, 'rb') as f:
                raw = f.read()
            self.data = msgpack.unpackb(raw, use_list=False)
        except (OSError, ValueError):
            # OSError: missing/unreadable file. ValueError: msgpack decode failures
            # (ExtraData, FormatError, incomplete input, ...) are ValueError subclasses.
            # Anything else (KeyboardInterrupt, MemoryError, bad arguments) now
            # propagates instead of being silently swallowed by a bare except.
            self.data = []

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

    
def get_train_dl(train_ds, num_workers=0, pin_memory=False,
                 collate_fn=None,
                 batch_size=3):
    """Training DataLoader: reshuffled every epoch; the trailing partial batch is dropped."""
    return DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

def get_valid_dl(valid_ds, num_workers=0, pin_memory=False,
                 collate_fn=None,
                 batch_size=3):
    """Validation DataLoader in fixed order.

    Note that ``drop_last=True`` means up to ``batch_size - 1`` trailing samples
    are never evaluated.
    """
    options = {
        "collate_fn": collate_fn,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "drop_last": True,
        "shuffle": False,
    }
    return DataLoader(valid_ds, **options)

class DsLoader():
    """Day-by-day loader for the daily ``202008DD.msgpack`` files with one-day look-ahead.

    ``get_train_ds(day)`` returns the dataset for ``day``, loading it synchronously
    if it is not cached. It then starts a daemon thread that loads ``day + 1`` into
    the cache. Unless ``cache_all`` is true, the previous day's dataset is evicted
    on each call.
    """

    def __init__(self, cache_all=False):
        self.ds_dct = {}  # zero-padded day string ("05") -> dataset
        self.cache_all = cache_all

    def get_train_ds(self, day):
        if not self.cache_all:
            # Drop the previous day's dataset (no-op if it is not cached).
            self.ds_dct.pop(f"{day - 1:02}", None)

        key = f"{day:02}"
        ds = self.ds_dct.get(key)
        if ds is None:
            ds = self.load_ds(day)
            self.ds_dct[key] = ds
        self.preload_next(day + 1)
        return ds

    def preload_next(self, day):
        # `threading` is not among this notebook's explicit imports; import it here
        # so preloading does not fail with a NameError.
        import threading
        threading.Thread(target=self.preload_ds, args=(day,), daemon=True).start()

    def preload_ds(self, day):
        key = f"{day:02}"
        if key not in self.ds_dct:
            self.ds_dct[key] = self.load_ds(day)

    def load_ds(self, day):
        return GisDS(msg_dir/f"202008{day:02}.msgpack")

In [80]:
# DataLoader settings for inference.
args = Box(dict(
    batch_size=512,
    num_workers=20,
    pin_memory=True,
))

In [81]:
# Use the first GPU when one is available (same "cuda:0" as before); fall back to CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# eval() changes nothing numerically for this architecture (it has no dropout or
# batch-norm), but it puts the model in inference mode explicitly.
model = CombineModel().to(device).eval()

# map_location lets a checkpoint saved on one device be restored onto `device`.
model.load_state_dict(torch.load(ckpt_path, map_location=device))

<All keys matched successfully>

In [82]:
%%time
test_ds = GisDS(msg_dir / "20200901_test.msgpack")
# GisDS returns an empty dataset when the file cannot be read or decoded. Stop here
# instead of failing later with a confusing length mismatch on the submission.
assert len(test_ds) > 0, f"no test records loaded from {msg_dir / '20200901_test.msgpack'}"
len(test_ds)

CPU times: user 7.04 s, sys: 3.67 s, total: 10.7 s
Wall time: 10.7 s


288076

In [83]:
def get_test_dl(test_ds, collate_fn=None, num_workers=0, pin_memory=False, batch_size=3):
    """Inference DataLoader: keeps the input order and the final partial batch."""
    loader_kwargs = dict(
        collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,  # every test sample needs a prediction
        shuffle=False,    # predictions must line up with the input order
    )
    return DataLoader(test_ds, **loader_kwargs)

# Order-preserving loader over the whole test set. `collate` is passed directly:
# a module-level function can be pickled for worker processes (needed under the
# "spawn" start method), whereas a lambda cannot. The call is the same, since
# collate's `date` defaults to None.
test_dl = get_test_dl(test_ds,
                      num_workers=args.num_workers,
                      pin_memory=args.pin_memory,
                      batch_size=args.batch_size,
                      collate_fn=collate)

model.eval()
res = []
# no_grad: inference needs no autograd graph, which saves memory and time.
with torch.no_grad():
    for features, _ in tqdm(test_dl):
        # The model predicts standardised log-ETA: undo the standardisation, then the log.
        pred = torch.exp(
            model(dict2device(features, device)) * eta_std + eta_mean
        )
        res.extend(pred.tolist())

  0%|          | 0/563 [00:00<?, ?it/s]

In [84]:
submit_df = pd.read_csv(data_dir/"sample_submission.csv")

# Predictions are assigned by position. NOTE(review): this assumes
# sample_submission.csv lists orders in the same order as 20200901_test.msgpack --
# confirm, or join on the order id instead.
assert len(res) == len(submit_df), f"{len(res)} predictions for {len(submit_df)} submission rows"
# Item assignment always writes the column. Attribute assignment (`submit_df.result = ...`)
# silently creates a plain attribute instead if the column does not already exist.
submit_df["result"] = res

dt_string = datetime.now().strftime("%Y_%m_%d_%H_%M")

os.makedirs("./predict/", exist_ok=True)

csv_fname = f"./predict/pred_{dt_string}.csv"
# zip_fname = f"./predict/pred_{dt_string}.zip"

submit_df.to_csv(csv_fname, index=False)

# with zipfile.ZipFile(zip_fname, 'w') as zf:
#     zf.write(csv_fname, compress_type=zipfile.ZIP_DEFLATED)