In [1]:
# Reload changed imported modules automatically before running each cell (%autoreload 2).
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
from tqdm import tqdm
from glob import glob
import os, sys
import numpy as np
from configs import *
from utils import *
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np
from dataset import *
# Fix RNG seeds so runs are repeatable.
torch.manual_seed(0)  # torch's default generators (documented to seed all devices)
np.random.seed(0)     # numpy's legacy global RandomState (np.random.*)

In [3]:
ESIZE = 64  # NOTE(review): not referenced in the cells shown here
EPS = 120   # presumably the number of training epochs — not referenced yet
BSIZE = 32  # mini-batch size
HSIZE = 64  # hidden-state width used by Node.h and the Kernel/Update ops
LAG = 12    # passed to SingleStop as `lag=`
SROUTE = SAMPLE_ROUTES[0]  # first sample route (SAMPLE_ROUTES comes from a star import)

# Prefer the second GPU, but fall back to the first one (or the CPU) instead of
# failing with "invalid device ordinal" on machines that have only one GPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
else:
    device = torch.device("cuda:0")
device

device(type='cuda', index=1)

In [4]:
# Train and test splits of SingleStop for the same route, stop (30) and lag.
# (A `.generator()` variant of these datasets exists but is not used here.)
dset, evalset = [
    SingleStop(SROUTE, 30, split_name, 32, lag=LAG)
    for split_name in ('train', 'test')
]

Locals dataset: train
 [*] Loaded routes: 1 (0.22s)
 [*] Has trainable inds: 262262
 [*] Subset train: 212106
 [*] Subset in Stop-30: 6079
Locals dataset: test
 [*] Loaded routes: 1 (0.24s)
 [*] Has trainable inds: 262262
 [*] Subset test: 50156
 [*] Subset in Stop-30: 1549


In [5]:
class Node:
    """One stop in a per-timestep chain of stops (see routeToGraph).

    Attributes:
        _v: independent copy of the input tensor on `device`; used as the label.
        v:  the input tensor moved to `device`; the working value.
        h:  zero-initialised hidden state of shape (batch, HSIZE) on `device`.
        ns: list of neighbouring nodes (the chain continues through ns[0]).
    """

    def __init__(self, value):
        self._v = value.clone().to(device)  # label
        # `.to` returns `value` itself when it already lives on `device`, so `v`
        # may alias the caller's tensor.
        self.v = value.to(device)
        # Size the hidden state from the value's leading (batch) dimension rather
        # than the global BSIZE, so batches of any size line up with their values.
        rows = value.shape[0] if value.dim() > 0 else BSIZE
        self.h = torch.zeros(rows, HSIZE, device=device)
        self.ns = []  # neighbors

    def show(self):
        """Print `v.size()` of every node along the ns[0] chain that has a neighbour.

        The chain's last node (no neighbours) is not printed, matching how end
        nodes are skipped elsewhere. A trailing newline is always printed.
        """
        pnt = self
        while pnt.ns:
            print(pnt.v.size(), end=' ')
            pnt = pnt.ns[0]
        print()

    def ln(self):
        """Return the number of hops along the ns[0] chain (node count minus one)."""
        hops = 0
        pnt = self
        while pnt.ns:
            pnt = pnt.ns[0]
            hops += 1
        return hops
        
def routeToGraph(batch):
    """Turn a batch tensor into one linked chain of Nodes per timestep.

    Dimension 1 of `batch` is split into timesteps (oldest first, t-h ... t).
    Within each timestep, the remaining dimension 1 is split into stops. Each
    stop becomes a Node whose `ns` holds the next stop.

    Returns:
        list with the root Node of each timestep's chain (None if it has no stops).
    """
    at_time = []  # t-h ... t
    for step in batch.split(1, dim=1):
        chain = [Node(stop) for stop in step.squeeze(1).split(1, dim=1)]
        # link every stop to its successor
        for upstream, downstream in zip(chain, chain[1:]):
            upstream.ns.append(downstream)
        at_time.append(chain[0] if chain else None)
    return at_time

def inst_tree(struct, nodes):
    """Build a tree of ops that mirrors the neighbour structure of `nodes`.

    For every node (parents before children, in order), `struct(node)` is
    called, the result gets a `.device` attribute set to the global `device`,
    and it is wrapped as {'op': instance, 'ns': [child entries...]}.

    Returns:
        list of such dicts, one per element of `nodes`.
    """
    def build(node):
        op = struct(node)  #.to(device)
        op.device = device
        return dict(op=op, ns=inst_tree(struct, node.ns))

    return [build(node) for node in nodes]

In [6]:
# Op trees are built lazily from the first batch's graph inside the training loop.
kernels = upops = None

In [7]:
# Mean-squared error between iterated node values and the next timestep's labels.
criterion = nn.MSELoss().to(device)
# TODO: no optimizer yet. Planned: optim.SGD over the Kernel/Update parameters
# (lr to be chosen), with optim.lr_scheduler.StepLR(opt, step_size=15, gamma=0.2).
# Note that `torch.optim` is not imported in this notebook yet.

In [8]:
from models.Kernel import Kernel, Update

def zip_op(t1, t2, op):
    """Walk two trees in lock-step, breadth-first, calling ``op(n1, n2)`` on each pair.

    Args:
        t1: dict tree whose children are in ``t1['ns']`` (as built by inst_tree).
        t2: object tree whose children are in ``t2.ns`` (e.g. Node).
        op: callable applied to each (t1-node, t2-node) pair; its result is ignored.

    Children are paired with zip(), so extra children on either side are skipped.
    """
    from collections import deque

    # deque.popleft is O(1); slicing a list on every pop made the walk O(n^2).
    pending = deque([(t1, t2)])
    while pending:
        n1, n2 = pending.popleft()
        op(n1, n2)
        pending.extend(zip(n1['ns'], n2.ns))

def message(kernels, graph_t):
    """Run every kernel op in the tree `kernels` on the matching node of `graph_t`.

    Traversal and pairing are delegated to zip_op. Returns None.
    """
    def apply_kernel(kernel_entry, node):
        kernel_entry['op'](node)

    zip_op(kernels, graph_t, op=apply_kernel)
    
def update(kernels, graph_t):
    """Run every update op in the tree `kernels` on the matching node of `graph_t`.

    Despite the parameter name (kept for call compatibility), `kernels` is the
    update-op tree; the training loop passes `upops`. The tree passed in is
    used, not any global. Traversal and pairing are delegated to zip_op.
    """
    zip_op(kernels, graph_t, op=lambda up, node: up['op'](node))
    
def gather_predictions(_node, node):
    """Collect (label, prediction) pairs from two structurally matching node trees.

    `_node` supplies labels (its `._v`), `node` supplies predictions (its `.v`).
    The trees are walked together depth-first in pre-order, pairing children
    with zip().

    Returns:
        list of (_node._v, node.v) tuples, one per `node` that has neighbours.
    """
    # end nodes do not hold convolved results, so they are ignored
    own = [(_node._v, node.v)] if node.ns else []
    return own + [
        pair
        for label_child, state_child in zip(_node.ns, node.ns)
        for pair in gather_predictions(label_child, state_child)
    ]

# Iterate over full mini-batches only. A trailing partial batch would make
# `dset[ii + jj]` index past the end of the dataset.
for ii in range(0, len(dset) - BSIZE + 1, BSIZE):
    # Stack BSIZE samples along a new leading dim. routeToGraph then splits
    # dim 1 into timesteps and the following dim into stops.
    samples = np.array([dset[ii + jj] for jj in range(BSIZE)])
    batch = routeToGraph(torch.Tensor(samples))  # one stop chain per timestep

    states = batch[0]  # will be used to hold iterated values
    if kernels is None:
        # Build the op trees once, mirroring the first graph's structure.
        kernels = inst_tree(
            lambda node: Kernel(insize=1 + len(node.ns), hsize=HSIZE).to(device),
            [states])[0]
        upops = inst_tree(
            lambda _: Update(hsize=HSIZE).to(device),
            [states])[0]

    # initial iteration
    message(kernels, states)
    update(upops, states)

    for graph_t in batch[1:]:
        # compute loss: iterated values (ys) vs. this timestep's labels (_ys)
        _ys, ys = zip(*gather_predictions(graph_t, states))
        _ys, ys = torch.cat(_ys, dim=1), torch.cat(ys, dim=1)
        loss = criterion(ys, _ys)  # TODO: ignore loss at edges
        print(loss.item(), _ys.size())

        # iterate
        message(kernels, states)
        update(upops, states)

    # TODO: apply loss to all newvals (no backward() / optimizer step yet)
    break  # debug: only the first mini-batch is processed for now

13.579473495483398 torch.Size([32, 9])
12.546544075012207 torch.Size([32, 9])
16.351491928100586 torch.Size([32, 9])
16.085342407226562 torch.Size([32, 9])
18.43267822265625 torch.Size([32, 9])
14.53554916381836 torch.Size([32, 9])
12.086894989013672 torch.Size([32, 9])
13.768434524536133 torch.Size([32, 9])
16.079240798950195 torch.Size([32, 9])
13.26963996887207 torch.Size([32, 9])
16.51372528076172 torch.Size([32, 9])
