In [1]:
%matplotlib inline

import random
from functools import partial
import time

import torch
import torch.nn as nn
import numpy as np
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt

from dqnroute.utils import *
from dqnroute.networks import *

%load_ext autoreload
%autoreload 2

In [3]:
# Load the pre-training samples; the first CSV column becomes the row index.
data = pd.read_csv('../src/pretrain_data.csv', index_col=0)

In [4]:
# Preview the first rows to check the column layout (dst, src, pkg_id, nbr, amatrix_*, estim).
data.head()

Unnamed: 0,dst,src,pkg_id,nbr,amatrix_0,amatrix_1,amatrix_2,amatrix_3,amatrix_4,amatrix_5,...,amatrix_91,amatrix_92,amatrix_93,amatrix_94,amatrix_95,amatrix_96,amatrix_97,amatrix_98,amatrix_99,estim
0,9.0,7.0,0.0,6.0,0.0,15.0,15.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15.0,0.0,0.0,15.0,0.0,-60.0
1,9.0,7.0,0.0,3.0,0.0,15.0,15.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15.0,0.0,0.0,15.0,0.0,-60.0
2,9.0,7.0,0.0,5.0,0.0,15.0,15.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15.0,0.0,0.0,15.0,0.0,-30.0
3,9.0,5.0,0.0,4.0,0.0,15.0,15.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15.0,0.0,0.0,15.0,0.0,-45.0
4,9.0,5.0,0.0,7.0,0.0,15.0,15.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15.0,0.0,0.0,15.0,0.0,-45.0


In [6]:
# (rows, columns) of the loaded data set.
data.shape

(165293, 105)

In [7]:
def shuffle(df):
    """Return a copy of ``df`` with its rows in random order.

    Rows are selected by position rather than by label, so the function also
    works when the index contains duplicate labels (a label-based ``reindex``
    raises in that case). The original index labels travel with their rows.

    Randomness comes from NumPy's global RNG, so seeding it makes the result
    repeatable.

    :param df: pandas DataFrame to shuffle; it is not modified.
    :return: new DataFrame containing the same rows in permuted order.
    """
    return df.iloc[np.random.permutation(len(df))]

In [8]:
class CachedEmbs:
    """Cache of fitted node encoders, one per distinct adjacency matrix.

    ``InnerEmbs`` is any callable that builds an encoder from ``dim``; the
    encoder must provide ``fit(amatrix)`` and ``encode(nodes)``. An encoder is
    built and fitted the first time a given matrix is seen and reused for
    every later request with the same matrix.
    """

    def __init__(self, InnerEmbs, dim):
        """
        :param InnerEmbs: encoder factory, called as ``InnerEmbs(dim)``.
        :param dim: value forwarded to the factory for every new encoder.
        """
        self.InnerEmbs = InnerEmbs
        self.dim = dim

        # Maps get_hash(amatrix) -> fitted encoder.
        self.cache = {}

    def get_hash(self, amatrix):
        """Return the cache key for ``amatrix``.

        The key is built from the array's shape, dtype and raw bytes.
        ``str(array)`` is not a safe key: NumPy abbreviates large arrays with
        ``...`` and rounds the printed values, so different matrices could
        end up sharing one cached encoder.
        """
        arr = np.asarray(amatrix)
        if arr.dtype == object:
            # tobytes() on an object array serialises pointers, not values,
            # so hash the element values themselves.
            return hash((arr.shape, tuple(arr.ravel().tolist())))
        return hash((arr.shape, arr.dtype.str, arr.tobytes()))

    def fit(self, amatrix):
        """Build and fit an encoder for ``amatrix`` unless one is already cached."""
        h = self.get_hash(amatrix)
        if h not in self.cache:
            embs = self.InnerEmbs(self.dim)
            embs.fit(amatrix)
            self.cache[h] = embs

    def encode(self, amatrix, nodes):
        """Encode ``nodes`` with the encoder cached for ``amatrix``.

        :raises KeyError: if ``fit`` was never called with this matrix.
        """
        h = self.get_hash(amatrix)
        return self.cache[h].encode(nodes)

In [14]:
# Iterate over the training data in mini-batches
def qnetwork_batches(addit_inputs, train_df, batch_size, embs):
    """Yield ``(inputs, targets)`` pairs for consecutive slices of ``train_df``.

    For every row the adjacency-matrix columns (those whose name starts with
    ``'amatrix'``) are passed to ``embs.fit`` and then used to encode the
    row's ``'src'``, ``'dst'`` and ``'nbr'`` values via ``embs.encode``.

    :param addit_inputs: iterable of dicts with a ``'name'`` key; every entry
        named ``'amatrix'`` adds the adjacency-matrix columns as an extra input.
    :param train_df: DataFrame with ``src``, ``dst``, ``nbr``, ``estim`` and
        ``amatrix*`` columns.
    :param batch_size: forwarded to ``make_batches`` together with the row count.
    :param embs: object providing ``fit(amatrix)`` and ``encode(amatrix, node)``.
    :return: generator of ``(inputs, targets)``; ``inputs`` is a list of NumPy
        arrays (src, dst, nbr encodings and, when requested, the extra
        columns), ``targets`` is a float tensor built from ``'estim'``.
    """
    amatrix_cols = [name for name in train_df if name.startswith('amatrix')]

    extra_cols = []
    for spec in addit_inputs:
        if spec['name'] == 'amatrix':
            extra_cols += amatrix_cols

    for lo, hi in make_batches(len(train_df), batch_size):
        chunk = train_df[lo:hi]

        src_codes, dst_codes, nbr_codes, extras = [], [], [], []

        for pos in range(len(chunk)):
            record = chunk.iloc[pos]
            adjacency = record[amatrix_cols].values
            # Make sure an encoder for this adjacency matrix exists before encoding.
            embs.fit(adjacency)

            for sink, column in ((src_codes, 'src'),
                                 (dst_codes, 'dst'),
                                 (nbr_codes, 'nbr')):
                sink.append(embs.encode(adjacency, record[column]))

            extra = record[extra_cols].values
            if len(extra) > 0:
                extras.append(extra)

        inputs = [np.array(codes) for codes in (src_codes, dst_codes, nbr_codes)]
        if extras:
            inputs.append(np.array(extras))

        yield inputs, torch.tensor(chunk['estim'].values, dtype=torch.float)


# One epoch of optimization
def qnetwork_pretrain_epoch(model, optimizer, data, embs, batch_size=64):
    """Run one optimization pass over ``data`` and yield each batch's loss.

    :param model: network called as ``model(*batch)``; its ``addit_inputs``
        attribute selects the extra inputs produced by ``qnetwork_batches``.
    :param optimizer: optimizer over the model's parameters.
    :param data: DataFrame of training samples, passed to ``qnetwork_batches``.
    :param embs: node encoder cache, passed to ``qnetwork_batches``.
    :param batch_size: number of rows per batch (defaults to 64, the value
        that used to be hard-coded).
    :return: generator yielding the MSE loss of every batch as a Python float.
    """
    loss_fn = nn.MSELoss()

    for batch, target in qnetwork_batches(model.addit_inputs, data, batch_size, embs):
        # Clear the gradient accumulated during the previous backward pass.
        optimizer.zero_grad()

        pred = model(*batch)
        # Targets get a trailing axis of size 1 before the comparison;
        # presumably the model outputs one value per sample - TODO confirm.
        loss = loss_fn(pred, target.unsqueeze(1))
        loss.backward()

        # Update the model parameters.
        optimizer.step()

        yield float(loss)


# Iterate over epochs
def qnetwork_pretrain(model,
                      data,
                      optim_name,
                      epoch_num,
                      embs,
                      need_save=True,
                      lr=0.001):
    """Pre-train ``model`` for ``epoch_num`` epochs and return the mean loss per epoch.

    The mean batch loss of every epoch is printed as soon as the epoch ends.

    :param model: network to train; ``model.save()`` is called afterwards
        when ``need_save`` is true.
    :param data: DataFrame of training samples.
    :param optim_name: optimizer name resolved through ``optim_class``.
    :param epoch_num: number of passes over ``data``.
    :param embs: node encoder cache forwarded to the epoch loop.
    :param need_save: whether to call ``model.save()`` after training.
    :param lr: optimizer learning rate (defaults to 0.001, the value that
        used to be hard-coded).
    :return: list with one mean batch loss per epoch.
    :raises ZeroDivisionError: if an epoch produces no batches.
    """
    optimizer = optim_class(optim_name)(model.parameters(), lr=lr)

    epoch_losses = []
    for _ in range(epoch_num):
        loss_sum = 0
        loss_num = 0
        for loss in qnetwork_pretrain_epoch(model, optimizer, data, embs):
            loss_sum += loss
            loss_num += 1

        # Average of the per-batch losses seen in this epoch.
        epoch_loss = loss_sum / loss_num
        print(epoch_loss)
        epoch_losses.append(epoch_loss)

    if need_save:
        model.save()

    return epoch_losses

In [15]:
plt.rc('font', size=14)

def plot_losses(losses_dict,
                name=None,
                figsize=(13, 7),
                title=None):
    """Plot per-epoch loss curves, one line per entry of ``losses_dict``.

    :param losses_dict: mapping ``label -> sequence of losses``; element ``i``
        is drawn at epoch ``i + 1``.
    :param name: file name for the saved figure, appended to ``'../img/'``;
        ``None`` (the default) skips saving.
    :param figsize: figure size passed to matplotlib.
    :param title: optional plot title.
    """
    fig, ax = plt.subplots(figsize=figsize)

    longest = 0
    for label, losses in losses_dict.items():
        epochs = list(range(1, len(losses) + 1))
        ax.plot(epochs, losses, label=label)
        longest = max(longest, len(losses))

    ax.set_xlabel('Epoch')
    ax.set_ylabel('MSE')
    # Ticks span the longest curve; taking them from the last plotted series
    # fails on an empty dict and drops ticks when the series differ in length.
    ax.set_xticks(list(range(1, longest + 1)))
    ax.grid()
    if losses_dict:
        # A legend without any plotted line only triggers a matplotlib warning.
        ax.legend()

    if title is not None:
        ax.set_title(title)

    if name is not None:
        fig.savefig('../img/' + name, bbox_inches='tight')

    plt.show()

In [16]:
# Fix the random seed before building models and shuffling data.
# set_random_seed is not defined in this notebook; presumably it comes from the dqnroute star-imports.
set_random_seed(40)

# QNetwork constructor pre-bound to request the adjacency matrix as an additional input.
QNetworkAmatrix = partial(QNetwork, addit_inputs=[{'name': 'amatrix'}])

# Number of graph nodes; presumably matches the 100 (10 x 10) amatrix_* columns - TODO confirm.
nodes_num = 10

## One-hot + adjacency matrix

In [17]:
# Q-network that also receives the adjacency matrix (see QNetworkAmatrix).
# The positional arguments are presumably node count, hidden layer sizes,
# activation name and embedding config - confirm against QNetwork.
model_am = QNetworkAmatrix(nodes_num,
                           [64, 64],
                           'relu',
                           {'name': 'oh'})

# Cached OHNodeEnc encoders, built with dim equal to the number of nodes.
oh_embs = CachedEmbs(OHNodeEnc, dim=nodes_num)

In [18]:
# Pre-train the adjacency-matrix model for 10 epochs with 'rmsprop' on a shuffled copy of the data.
# need_save is left at its default, so model.save() is called when training finishes.
# NOTE(review): the recorded run of this cell was interrupted, so losses_am
# was never assigned in that session.
losses_am = qnetwork_pretrain(model_am,
                              shuffle(data),
                              'rmsprop',
                              10,
                              oh_embs)

KeyboardInterrupt: 

## Laplacian eigenmaps only

In [19]:
# Q-network without additional inputs, configured with the 'le' embedding of dimension 4.
model_le = QNetwork(nodes_num,
                    [64, 64],
                    'relu',
                    {'name': 'le', 'dim': 4})

# Cached LENodeEnc encoders with dim=4, matching the 'dim' given to the model above.
le_embs = CachedEmbs(LENodeEnc, dim=4)

In [20]:
# Keep only the samples whose pkg_id is below 5000.
# NOTE(review): data_full_graph is not passed to any training call in the
# cells visible here (training uses shuffle(data)) - confirm which data set
# the LE model is meant to be trained on.
data_full_graph = data[data['pkg_id'] < 5000]
data_full_graph.head()

Unnamed: 0,dst,src,pkg_id,nbr,amatrix_0,amatrix_1,amatrix_2,amatrix_3,amatrix_4,amatrix_5,...,amatrix_91,amatrix_92,amatrix_93,amatrix_94,amatrix_95,amatrix_96,amatrix_97,amatrix_98,amatrix_99,estim
0,9.0,7.0,0.0,6.0,0.0,15.0,15.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15.0,0.0,0.0,15.0,0.0,-60.0
1,9.0,7.0,0.0,3.0,0.0,15.0,15.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15.0,0.0,0.0,15.0,0.0,-60.0
2,9.0,7.0,0.0,5.0,0.0,15.0,15.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15.0,0.0,0.0,15.0,0.0,-30.0
3,9.0,5.0,0.0,4.0,0.0,15.0,15.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15.0,0.0,0.0,15.0,0.0,-45.0
4,9.0,5.0,0.0,7.0,0.0,15.0,15.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,15.0,0.0,0.0,15.0,0.0,-45.0


In [21]:
# Show the model's label, then pre-train it for 10 epochs with 'rmsprop' on a shuffled copy of the full data.
# NOTE(review): the recorded run of this cell was interrupted, so losses_le
# was never assigned in that session.
print(model_le.label)
losses_le = qnetwork_pretrain(model_le,
                              shuffle(data),
                              'rmsprop',
                              10,
                              le_embs)

qnetwork_8_64-64_relu_le-4


KeyboardInterrupt: 

In [None]:
# Compare the per-epoch losses of both models.
# name=None shows the figure without saving it; pass a file name
# (e.g. 'foobar') to also write it under ../img/.
plot_losses({
    'One-hot + AM': losses_am,
    'LE': losses_le
}, name=None)