In [None]:
import numpy as np
import torch as th

device = th.device('cuda') if th.cuda.is_available() else th.device('cpu')
print('device:', device)

# Best state seen so far per model id: {id: {'loss': float, 'state': state_dict}}.
saved_models = {}

In [None]:
# import os

# from torch.utils.data import Dataset, DataLoader, random_split
# from collections import OrderedDict

# from typing import List

# class MovementDataSet(Dataset):
#   def __init__(self, data_path: str):
#     self.length = -1
#     def load_data(type: str):
#       dir_path: str = os.path.join(data_path, type) # ex: data/inputs
#       data = OrderedDict()
#       for f in os.listdir(dir_path):
#         t = th.tensor(np.load(os.path.join(dir_path, f)).astype(np.float32))
#         if self.length != -1 and len(t) != self.length:
#           raise Exception('Length mismatch. Expected: {}. Found: {}'.format(self.length, len(t)))
#         data[f.replace('.npy', '')] = t
        
#       return data
#     self.inputs = load_data('input')
#     self.outputs = load_data('output')

#   def __len__(self):
#     return len(self.inputs[0])

#   def __getitem__(self, idx):
#     def get_line(data):
#       return {k: v[idx] for k,v in data.items()}
#     return get_line(self.inputs), get_line(self.outputs)

# move_dataset = MovementDataSet('data/only_side_moves')
# print(move_dataset[0])


In [None]:
# Scratch check: set.difference returns a new set and leaves `a` untouched.
a = {1, 2, 3, 4}
a.difference([1, 2])

In [None]:
import sys

import os

from torch.utils.data import Dataset, random_split
from collections import OrderedDict

from typing import Tuple, Optional


def load_tensor(*segments):
  """Load the ``.npy`` file at ``os.path.join(*segments)`` as a float32 tensor."""
  path = os.path.join(*segments)
  array = np.load(path).astype(np.float32)
  return th.tensor(array)


def get_slice(data_dict: dict[str, th.Tensor], idx):
  """Index every value of ``data_dict`` with ``idx`` (int, slice or index list), keeping the keys."""
  sliced = {}
  for key, tensor in data_dict.items():
    sliced[key] = tensor[idx]
  return sliced

def _reduce_rows(data_dict: dict[str, th.Tensor], keys: Optional[list[str]], reduce) -> dict:
  """Apply ``reduce`` to each selected tensor; every key is selected when ``keys`` is None.

  Single-element results are returned as Python numbers (via ``.item()``).
  Larger results, e.g. from multi-column features, are returned as nested
  lists instead of raising in ``.item()``.
  """
  r = {}
  for k, v in data_dict.items():
    if keys is None or k in keys:
      stat = reduce(v)
      r[k] = stat.item() if stat.numel() == 1 else stat.tolist()
  return r

def get_avg(data_dict: dict[str, th.Tensor], keys: Optional[list[str]] = None):
  """Per-key mean over the sample dimension (dim 0), restricted to ``keys`` unless it is None."""
  return _reduce_rows(data_dict, keys, lambda v: v.mean(dim=0))

def get_std(data_dict: dict[str, th.Tensor], keys: Optional[list[str]] = None):
  """Per-key (unbiased) standard deviation over dim 0, restricted to ``keys`` unless it is None."""
  return _reduce_rows(data_dict, keys, lambda v: v.std(dim=0))

def view(data_dict: dict, fn_dict: Optional[dict] = None) -> OrderedDict:
  """Project ``data_dict`` onto the keys of ``fn_dict``, transforming values on the way.

  - If ``fn_dict`` is falsy (None or empty), return a shallow OrderedDict copy of ``data_dict``.
  - Otherwise, drop keys absent from ``fn_dict``. For kept keys, apply ``fn_dict[k]``
    to the value when it is truthy; a falsy entry (e.g. None) keeps the value unchanged.

  The output order follows ``data_dict``.
  """
  if not fn_dict:
    return OrderedDict(data_dict)
  projected = OrderedDict()
  for key in data_dict:
    if key in fn_dict:
      transform = fn_dict[key]
      projected[key] = transform(data_dict[key]) if transform else data_dict[key]
  return projected


def load_all_data(data_path: str) -> Tuple[OrderedDict, OrderedDict]:
  """Load every ``.npy`` file under ``<data_path>/input`` and ``<data_path>/output``.

  Returns two OrderedDicts mapping the file name without its ``.npy`` suffix
  to a float32 tensor (see ``load_tensor``).

  Files are read in sorted name order, so the key order (and thus the column
  order of printed tables) does not depend on ``os.listdir``'s arbitrary order.
  Entries that are not ``.npy`` files are skipped.
  """
  def load(kind: str) -> OrderedDict:
    dir_path: str = os.path.join(data_path, kind)  # e.g. <data_path>/input
    data = OrderedDict()
    for f in sorted(os.listdir(dir_path)):
      if not f.endswith('.npy'):
        continue
      # Strip only the suffix; str.replace would also remove '.npy' inside a name.
      data[f[:-len('.npy')]] = load_tensor(dir_path, f)
    return data
  return load('input'), load('output')
    
    
def load_simple_move_data(data_path: str) -> Tuple[OrderedDict, OrderedDict]:
  """Load the first column of a fixed set of input/output arrays from ``data_path``.

  Inputs: ``dir`` and ``vel`` from ``input/``. Outputs: ``dpos`` and ``vel`` from ``output/``.
  """
  def first_column(kind, file_name):
    return load_tensor(data_path, kind, file_name)[:, 0]

  input_files = (('dir', 'dir.npy'), ('vel', 'vel.npy'))
  output_files = (('dpos', 'dpos.npy'), ('vel', 'vel.npy'))
  inputs = OrderedDict((key, first_column('input', name)) for key, name in input_files)
  outputs = OrderedDict((key, first_column('output', name)) for key, name in output_files)
  return inputs, outputs
  
  
def write_same_line(*args):
  """Write the string fragments to stdout back to back, without adding a newline."""
  text = ''.join(args)
  sys.stdout.write(text)
        
   
class KeyValueDataset(Dataset):
  """Column-oriented dataset holding ``inputs`` and ``outputs`` as dicts of tensors.

  Every tensor is indexed by sample along its first dimension. ``ds[idx]``
  (int, slice or list of ints) applies ``idx`` to every tensor and returns the
  pair ``(inputs_slice, outputs_slice)``. Most methods build a new
  ``KeyValueDataset`` from such a pair.
  """

  def __init__(self, inputs, outputs):
    self.inputs = inputs
    self.outputs = outputs
    # The sample count comes from the first input tensor only; the other
    # tensors are assumed (not checked) to share that first dimension.
    first = next(iter(inputs.values()), None)
    self.length = 0 if first is None else len(first)

  def __len__(self):
    return self.length

  def __getitem__(self, idx):
    return get_slice(self.inputs, idx), get_slice(self.outputs, idx)

  def select_indexes(self, fn):
    """Row-by-row filter: indexes ``i`` for which ``fn(inputs_i, outputs_i)`` is truthy."""
    return [i for i in range(len(self)) if fn(*self[i])]

  def select_indexes_parallel(self, fn):
    """Vectorised filter over the whole dataset.

    ``fn`` receives the full input/output dicts and must return a 1-D boolean
    mask over samples; the indexes where it is true are returned.
    """
    idx: th.Tensor = th.nonzero(fn(self.inputs, self.outputs))
    # nonzero gives one row per hit; view() fails loudly if the mask was not 1-D.
    return idx.view(len(idx)).tolist()

  def select(self, fn):
    return KeyValueDataset(*self[self.select_indexes(fn)])

  def select_parallel(self, fn):
    return KeyValueDataset(*self[self.select_indexes_parallel(fn)])

  def random_split(self, prop, generator=None):
    """Randomly split into two datasets; ``prop`` is passed as ``lengths`` to torch's ``random_split``."""
    a, b = random_split(self, prop, generator)
    return KeyValueDataset(*a[:]), KeyValueDataset(*b[:])

  def _split_by_indexes(self, idx):
    """Return (rows in ``idx``, all other rows); the remainder keeps ascending row order."""
    chosen = set(idx)
    # Build the complement by scanning range(len) rather than iterating a set,
    # whose order is not guaranteed.
    rest = [i for i in range(len(self)) if i not in chosen]
    return KeyValueDataset(*self[idx]), KeyValueDataset(*self[rest])

  def condition_split(self, fn):
    """Split into (rows where ``fn`` holds, the rest), evaluating ``fn`` row by row."""
    return self._split_by_indexes(self.select_indexes(fn))

  def condition_split_parallel(self, fn):
    """Like ``condition_split`` but ``fn`` is evaluated once on the whole dataset (see ``select_indexes_parallel``)."""
    return self._split_by_indexes(self.select_indexes_parallel(fn))

  def view(self, input_view=None, output_view=None):
    """Project/transform inputs and outputs with the module-level ``view`` helper."""
    return KeyValueDataset(view(self.inputs, input_view), view(self.outputs, output_view))

  def avg(self, input_keys=None, output_keys=None):
    return get_avg(self.inputs, input_keys), get_avg(self.outputs, output_keys)

  def std(self, input_keys=None, output_keys=None):
    return get_std(self.inputs, input_keys), get_std(self.outputs, output_keys)

  def no_repetition(self):
    """Keep only the first sample for each distinct combination of input values (outputs are ignored)."""
    seen = set()
    idx = []
    for i in range(len(self)):
      inputs, _ = self[i]
      # Flatten each value so that non-scalar features can be part of the key too.
      key = tuple(tuple(v.reshape(-1).tolist()) for v in inputs.values())
      if key not in seen:
        idx.append(i)
        seen.add(key)
    return KeyValueDataset(*self[idx])

  def range(self, start=None, end=None):
    """Contiguous sub-dataset ``self[start:end]`` with ordinary Python slice semantics.

    ``None`` means the respective end of the dataset, and negative bounds count
    from the end. A modulo-``len`` reduction would turn the default
    ``end=len(self)`` into 0 and return an empty dataset.
    """
    return KeyValueDataset(*self[start:end])

  def print_yaml_like(self):
    """Print every sample as a ``---``-separated block of ``key: value`` lines (inputs, blank line, outputs)."""
    def print_data(data):
      for k, v in data.items():
        print('{}: {}'.format(k, v.tolist()))
    for inputs, outputs in self:
      print('---')
      print_data(inputs)
      print()
      print_data(outputs)

  def print_table(self):
    """Print a fixed-width table: input columns | output columns, one row per sample.

    Values use 4 decimals with trailing zeros (and a bare trailing '.') stripped,
    right-aligned to 7 characters. Every value must be a single-element tensor
    (``.item()``).
    """
    def print_header(keys):
      for k in keys:
        write_same_line(k.rjust(7), ' ')

    def format_value(v):
      return '{:.4f}'.format(v.item()).rstrip('0').rstrip('.').rjust(7)

    def print_data(data):
      write_same_line(*[format_value(v) + ' ' for v in data.values()])

    print_header(self.inputs.keys())
    write_same_line('|')
    print_header(self.outputs.keys())
    write_same_line('\n')

    for inputs, outputs in self:
      print_data(inputs)
      # Same separator as the header row so the columns line up under it.
      write_same_line('|')
      print_data(outputs)
      write_same_line('\n')

    sys.stdout.flush()

# Keep only the first column of each selected input/output field.
move_dataset = KeyValueDataset(*load_all_data('data/only_side_moves')).view({
    'vel': lambda x: x[:, 0],
    'dir': lambda x: x[:, 0]
  },{
    'vel': lambda x: x[:, 0],
    'dpos': lambda x: x[:, 0]
  })

# Compare samples without a direction input (dir == 0) against the rest.
nodir_dataset, dir_dataset = move_dataset.condition_split_parallel(lambda inputs, outputs: inputs['dir'] == 0)
print(nodir_dataset.avg(['dir'], []))
print(nodir_dataset.std(['dir'], []))
print(dir_dataset.avg(['dir'], []))
print(dir_dataset.std(['dir'], []))
print(len(move_dataset))
print(len(nodir_dataset))
print(len(dir_dataset))

In [None]:
import math
import matplotlib.pyplot as plt

# Input velocity vs. output velocity for every sample.
fig, ax = plt.subplots()
inputs, outputs = move_dataset[:]
ax.scatter(inputs['vel'], outputs['vel'])
ax.set_xlabel("input 'vel'")
ax.set_ylabel("output 'vel'")
ax.set_title('vel: input vs output');

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

from typing import Optional

class Vint:
  """Callable holder for an optional int.

  ``v(5)`` stores 5 and returns it. A later ``v()`` returns the stored value,
  or 0 if nothing has been stored yet.
  """
  def __call__(self, v: Optional[int]=None) -> int:
    # Compare with None rather than testing truthiness, so that v(0) stores 0
    # instead of being ignored.
    if v is not None:
      self.v: int = v
      return v
    return getattr(self, 'v', 0)
    
    


In [None]:
def map_reduce(a, kfn, redfn):
  """Group items of ``a`` by ``kfn(item)`` and fold each group with ``redfn``.

  ``redfn(acc, item)`` receives ``acc=None`` for the first item of a key.
  Returns a dict of key -> folded value, in first-seen key order.
  """
  groups = {}
  for item in a:
    key = kfn(item)
    groups[key] = redfn(groups.get(key), item)
  return groups

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

from typing import Callable


def save_if_better(id, min_loss, state_dict):
  """Store ``state_dict`` in ``saved_models[id]`` unless an equal-or-lower loss is already stored.

  A deep copy is stored because ``Module.state_dict()`` returns references to
  the live parameters. Keeping those would let later training silently
  overwrite the "best" weights.
  """
  import copy

  if id in saved_models and min_loss >= saved_models[id]['loss']:
    return
  saved_models[id] = {
    'loss': min_loss,
    'state': copy.deepcopy(state_dict)
  }

def create_model(mtype, load_state = True, **kargs):
  """Instantiate ``mtype(**kargs)`` and optionally restore its best saved state.

  The state is looked up in ``saved_models`` under the new model's ``id``.
  If the stored state does not fit the architecture, the error is reported and
  the fresh weights are kept. Note that ``load_state_dict`` may already have
  copied the matching tensors before raising.
  """
  model = mtype(**kargs)
  if load_state and model.id in saved_models:
    saved = saved_models[model.id]
    try:
      model.load_state_dict(saved['state'])
    except RuntimeError as err:
      # load_state_dict raises RuntimeError for missing/unexpected keys and
      # shape mismatches; anything else is a real error and should propagate.
      print('Saved state doesn\'t match model')
      print(err)
    else:
      print('Loaded model state. loss:', saved['loss'])
  return model

def get_lr(optimizer):
  """Return the learning rate of the optimizer's first parameter group as a float.

  Raises:
    Exception: if the optimizer has no parameter groups.
  """
  for param_group in optimizer.param_groups:
    return float(param_group['lr'])
  raise Exception('No learning rate found')
     
def fit(model, X, Y, Xval, Yval, epochs=1000, optimizer=None, lossfn: Callable=th.nn.MSELoss(), patiance = 50, min_lr=1e-9):
  """Full-batch training loop with LR-on-plateau decay and best-validation checkpointing.

  Each epoch computes the loss on all of (X, Y) and takes one optimizer step.
  The validation loss on (Xval, Yval) is computed with the parameters from
  *before* that step, in eval mode and without autograd. This way batch-norm
  layers use their running statistics and validation data never updates them.

  Args:
    model: module called as ``model(**X)`` when X is a dict, else ``model(X)``.
      If it has an ``id`` attribute, each new best state is also offered to
      ``save_if_better(model.id, ...)``.
    X, Y: training inputs and targets.
    Xval, Yval: validation inputs and targets.
    epochs: maximum number of epochs (one optimizer step each).
    optimizer: defaults to ``AdamW(model.parameters(), lr=1e-2)``.
    lossfn: called as ``lossfn(prediction, target)``.
    patiance: ReduceLROnPlateau patience, measured on the training loss.
    min_lr: training stops once the learning rate drops below this value.

  Returns:
    ``(losses, val_losses, min_val_loss, best_state)``. ``best_state`` is a deep
    copy of the state dict that produced ``min_val_loss`` (None if no epoch ran).
    Ctrl-C ends training early and returns what has been collected so far.
  """
  import copy

  losses = []
  val_losses = []
  min_val_loss = float('inf')
  best_state = None
  model_id = getattr(model, 'id', None)
  was_training = model.training

  def get_loss(X, Y):
    outputs = model(**X) if isinstance(X, dict) else model(X)
    # torch.nn losses take (input, target); MSE/L1 are symmetric, others are not.
    return lossfn(outputs, Y)

  try:
    if optimizer is None:
      optimizer = th.optim.AdamW(model.parameters(), lr=1e-2)

    def current_lr():
      return float(optimizer.param_groups[0]['lr'])

    lrDecaySch = ReduceLROnPlateau(optimizer, patience=patiance, eps=min_lr*0.1, threshold=1e-4)
    lr = current_lr()
    print('Learning rate: ', lr)

    for i in range(epochs):
      model.train()
      optimizer.zero_grad()
      loss = get_loss(X, Y)

      model.eval()
      with th.no_grad():
        curr_val_loss = get_loss(Xval, Yval).item()

      if i % 10 == 0:
        print('epoch {}, lr: {:.4}, loss {:.6}, val_loss: {:.6}'.format(i, lr, loss.item(), curr_val_loss))

      losses.append(loss.item())
      val_losses.append(curr_val_loss)
      if curr_val_loss < min_val_loss:
        min_val_loss = curr_val_loss
        # state_dict() holds live references to the parameters: copy it now,
        # before the step below, so the snapshot matches curr_val_loss.
        best_state = copy.deepcopy(model.state_dict())
        if model_id is not None:
          save_if_better(model_id, min_val_loss, best_state)

      loss.backward()
      optimizer.step()

      lrDecaySch.step(loss.item())
      new_lr = current_lr()
      if new_lr != lr:
        print('Reducing learning rate to {:.4}'.format(new_lr))
      lr = new_lr
      if lr < min_lr:
        break
  except KeyboardInterrupt:
    print('Training interrupted')
  finally:
    model.train(was_training)
  return losses, val_losses, min_val_loss, best_state

def plot_loss(loss, label=None, ax=None):
  """Plot a per-epoch loss curve.

  Args:
    loss: iterable of loss values, one per epoch.
    label: legend label. It is optional because several cells call
      ``plot_loss(losses)``; the legend is drawn only when a label is given.
    ax: matplotlib Axes to draw on. A new 10x4 figure is created when omitted.
  """
  if not ax:
    fig, ax = plt.subplots(figsize=(10, 4))
  loss = list(loss)  # allow generators: len() is needed for the x axis
  ax.plot(range(len(loss)), loss, lw=0.5, label=label)
  if label is not None:
    ax.legend()

In [None]:
class ModelSideMoveStopVel(nn.Module):
  """Small MLP: Linear(1, n) -> ReLU -> Linear(n, 1, no bias), applied to a batch of scalar velocities."""
  def __init__(self, n):
    super(ModelSideMoveStopVel, self).__init__()
    self.lin1 = nn.Linear(1, n)
    self.lin2 = nn.Linear(n, 1, bias=False)
    # fit() keys saved states by model.id and plot_error_histogram() puts it in
    # the plot title, so every model passed to them needs one.
    self.id = 'ModelSideMoveStopVel_{}'.format(n)

  def forward(self, vel):
    """vel: 1-D tensor of shape (batch,). Returns predictions of shape (batch,)."""
    x = self.lin1(vel[:, None])
    x = th.relu(x)
    x = self.lin2(x)
    return x.view(-1)

# class ModelSideMoveStopVel(nn.Module):
#   def __init__(self):
#     super(ModelSideMoveStopVel, self).__init__()
#     self.cap1 = nn.Parameter(th.tensor(0.2))
#     self.cap2 = nn.Parameter(th.tensor(0.2))
    

#   def forward(self, vel):
#     return th.relu(vel + self.cap1) - th.relu(-vel + self.cap2)
    

inputs, outputs = move_dataset[:]

# Sanity check on an untrained model: prediction and target shapes must agree
# before the initial loss below means anything.
with th.no_grad():
  model = ModelSideMoveStopVel(8)
  preds = model(inputs['vel'])
  print('preds:', preds.shape)
  print('outputs:', outputs['vel'].shape)
  print('loss:', th.nn.functional.mse_loss(preds, outputs['vel']))
  print(model.state_dict())

In [None]:
model_vel = ModelSideMoveStopVel(4)
# fit() requires validation data and returns four values. No validation split
# exists at this point, so the training data doubles as validation data.
losses, val_losses, min_loss, best_state = fit(model_vel, inputs['vel'], outputs['vel'], inputs['vel'], outputs['vel'],
                                               lossfn=nn.L1Loss(), epochs=10000,
                                               optimizer=th.optim.Adam(model_vel.parameters(), lr=0.1))
plot_loss(losses, 'train')

In [None]:
# model_vel.lin1.weight = nn.Parameter(th.tensor([[1], [-1]], dtype=th.float32))
# model_vel.lin1.bias = nn.Parameter(th.tensor([0.4, 0.1], dtype=th.float32))
# model_vel.lin2.weight = nn.Parameter(th.tensor([[0.1,  0.5]], dtype=th.float32))
# print(model_vel.state_dict())

def plot_fit(model, inputs, outputs, ax=None, start=-2, end=2, step=0.1):
  """Scatter (inputs, outputs) and overlay the model's prediction curve.

  The curve is evaluated on ``th.arange(start, end, step)``, which by default
  is [-2, 2) in steps of 0.1. ``model`` must therefore accept a 1-D tensor.
  ``ax`` allows drawing into an existing Axes; a new figure is created when omitted.
  """
  with th.no_grad():
    if ax is None:
      fig, ax = plt.subplots()
    ax.scatter(inputs, outputs)
    v0 = th.arange(start, end, step)
    v1 = model(v0)
    ax.plot(v0, v1, color='#C44')

In [None]:
def plot_error_histogram(model, inputs, outputs, res = 0.1, ax=None, title=''):
  """Bar-plot the distribution of prediction errors ``pred - outputs`` as % of samples.

  Args:
    model: called as ``model(**inputs)`` when ``inputs`` is a dict, else ``model(inputs)``.
    inputs: model input tensor or dict of tensors.
    outputs: targets, broadcast-compatible with the prediction.
    res: bin width. Bin k covers errors in ``[k*res, (k+1)*res)``.
    ax: matplotlib Axes to draw into. A new figure is created when omitted.
    title: title prefix; the model's ``id`` attribute is appended when present.

  Raises:
    ValueError: if there are no samples.
  """
  with th.no_grad():
    if isinstance(inputs, dict):
      pred = model(**inputs)
    else:
      pred = model(inputs)
    e = (pred - outputs).reshape(-1)

  n = e.numel()
  if n == 0:
    raise ValueError('plot_error_histogram: no samples to plot')

  # Floor-based binning keeps every bin res wide. int() truncates toward zero,
  # which merged (-res, res) into one double-width bin at 0.
  bin_ids, counts = th.unique(th.floor(e / res).long(), return_counts=True)
  lefts = [k * res for k in bin_ids.tolist()]
  percents = [c / n * 100 for c in counts.tolist()]

  if not ax:
    fig, ax = plt.subplots()

  ax.set_title(' '.join(part for part in (title, getattr(model, 'id', '')) if part))
  ax.set_ylabel('%')
  ax.set_xlabel('error')
  # Bars start at their bin's lower edge so each bar spans exactly its bin.
  ax.bar(lefts, percents, width=res, align='edge')
  ax.set_xlim(lefts[0] - res - 1, lefts[-1] + 2 * res + 1)

# plot_error_histogram(model_vel, inputs['vel'], outputs['vel'])

In [None]:
def within_error(inputs, outputs):
  """Per-sample filter: True when model_vel's squared velocity error is below 0.3."""
  prediction = model_vel(inputs['vel'][None])
  squared_error = (prediction - outputs['vel']) ** 2
  return squared_error.item() < 0.3

# The filter runs inside no_grad, so no autograd graph is built per sample.
with th.no_grad():
  filtered_dataset = move_dataset.select(within_error)

  inputs, outputs = filtered_dataset[:]
  fig, ax = plt.subplots()
  ax.scatter(inputs['vel'], outputs['vel'])

In [None]:
# Retrain on the filtered samples. fit() requires validation data and returns
# four values; with no split available, the training data doubles as validation.
losses, val_losses, min_loss, best_state = fit(model_vel, inputs['vel'], outputs['vel'], inputs['vel'], outputs['vel'],
                                               lossfn=nn.L1Loss(), epochs=10000,
                                               optimizer=th.optim.Adam(model_vel.parameters(), lr=0.1))
plot_loss(losses, 'train')

In [None]:
model_pos = ModelSideMoveStopVel(8)
# Same architecture, now predicting `dpos` from `vel`. With no split available,
# the training data doubles as validation data (fit() requires both).
losses, val_losses, min_loss, best_state = fit(model_pos, inputs['vel'], outputs['dpos'], inputs['vel'], outputs['dpos'],
                                               lossfn=nn.L1Loss(), epochs=10000,
                                               optimizer=th.optim.Adam(model_pos.parameters(), lr=0.1))
plot_loss(losses, 'train')

In [None]:
plot_fit(model_pos, inputs=inputs['vel'], outputs=outputs['dpos'])

In [None]:
plot_error_histogram(model_pos, inputs=inputs['vel'], outputs=outputs['dpos'])

In [None]:
# Reload the side-move data, keeping only the first column of each selected field.
move_dataset = KeyValueDataset(*load_all_data('data/only_side_moves')).view(
  input_view={
    'vel': lambda x: x[:, 0],
    'dir': lambda x: x[:, 0],
  },
  output_view={
    'vel': lambda x: x[:, 0],
    'dpos': lambda x: x[:, 0],
  },
)

In [None]:
KeyValueDataset.print_table(move_dataset)

In [None]:
class ModelSideMoveVel(nn.Module):
  """Two-input MLP: stack (vel, dir) -> Linear(2, n) -> act -> Linear(n, 1, no bias)."""
  def __init__(self, n, act=th.relu):
    super(ModelSideMoveVel, self).__init__()
    self.lin1 = nn.Linear(2, n)
    self.lin2 = nn.Linear(n, 1, bias=False)
    self.act = act
    # fit() keys saved states by model.id and plot_error_histogram() shows it
    # in the title. The 'VelDir' prefix keeps this architecture distinct from
    # other models keyed in saved_models.
    act_name = getattr(act, '__name__', type(act).__name__)
    self.id = 'ModelSideMoveVelDir_{}_{}'.format(n, act_name)

  def forward(self, vel, dir):
    """vel, dir: 1-D tensors of shape (batch,). Returns shape (batch,)."""
    x = th.cat([vel[:, None], dir[:, None]], dim=1)
    x = self.lin1(x)
    x = self.act(x)
    x = self.lin2(x)
    return x.view(-1)

In [None]:
model_vel = ModelSideMoveVel(4, th.relu)
# fit() requires validation data and returns four values. With no split
# available yet, the training data doubles as validation data.
losses, val_losses, min_loss, best_state = fit(model_vel, inputs, outputs['vel'], inputs, outputs['vel'],
                                               lossfn=nn.L1Loss(), epochs=10000,
                                               optimizer=th.optim.Adam(model_vel.parameters(), lr=0.1))
plot_loss(losses, 'train')

In [None]:
plot_error_histogram(model_vel, inputs=inputs, outputs=outputs['vel'])

In [None]:
# Jump dataset: first column of vel/dir/vel/dpos; a None view entry keeps the
# jump, onledge and wall fields unchanged.
move_dataset = KeyValueDataset(*load_all_data('data/side_moves_jump')).view(
  input_view={
    'vel': lambda x: x[:, 0],
    'dir': lambda x: x[:, 0],
    'jump': None,
    'onledge': None,
    'wall': None,
  },
  output_view={
    'vel': lambda x: x[:, 0],
    'dpos': lambda x: x[:, 0],
  },
)

In [None]:
from torch.utils.data import random_split

# Seeded split: save_if_better() compares each new validation loss with the
# one stored in saved_models, which is only meaningful on the same val set.
split_generator = th.Generator().manual_seed(130)
train_dataset, val_dataset = move_dataset.random_split([0.7, 0.3], generator=split_generator)
X_train, Y_train = train_dataset[:]
X_val, Y_val = val_dataset[:]

print()
print(train_dataset.avg(['vel'], ['dpos']))
print(val_dataset.avg(['vel'], ['dpos']))
print(train_dataset.std(['vel'], ['dpos']))
print(val_dataset.std(['vel'], ['dpos']))

In [None]:
class ModelSideMoveVel(nn.Module):
  """MLP over the features ``vel, dir, wall, onledge, jump``.

  Every feature is flattened to (batch, -1) and the results are concatenated.
  The combined tensor passes through one block per entry of ``dims``:
  LazyLinear(d) -> act() -> LazyBatchNorm1d. A final LazyLinear(1) follows.
  The lazy layers infer their input width on the first forward call.

  Args:
    dims: hidden layer widths.
    act: activation class, instantiated once per block (e.g. ``th.nn.ReLU``).
  """
  def __init__(self, dims, act = th.nn.ReLU):
    super(ModelSideMoveVel, self).__init__()
    self.layers = nn.Sequential()
    for d in dims:
      self.layers.append(nn.LazyLinear(d))
      self.layers.append(act())
      self.layers.append(nn.LazyBatchNorm1d())

    self.last = nn.LazyLinear(1)
    self.act = act
    # Architecture key used by create_model() / save_if_better() in saved_models.
    self.id = '_'.join(['ModelSideMoveVel', *[str(d) for d in dims], str(act.__name__)])

  def forward(self, vel, dir, wall, onledge, jump):
    """Each argument is a tensor whose first dimension is the batch. Returns shape (batch,)."""
    # reshape(len, -1) turns 1-D features into a column and flattens any
    # higher-rank ones (e.g. wall), so every feature can be concatenated.
    # Unlike view(), it also accepts non-contiguous tensors.
    features = [t.reshape(len(t), -1) for t in (vel, dir, wall, onledge, jump)]
    x = th.cat(features, dim=1)
    x = self.layers(x)
    x = self.last(x)
    return x.view(-1)

# Smoke test: plot the validation error distribution of an untrained model.
model = ModelSideMoveVel([16, 16], act=th.nn.ReLU)
# Eval mode: in train mode, batch norm would normalise with batch statistics
# and update its running statistics from the validation data.
model.eval()
print(model.id)
plot_error_histogram(model, X_val, Y_val['dpos'], res=0.01)

In [None]:
# Starts from the best stored state for this architecture, if any. The global
# `device` from the first cell is left untouched (nothing here moves tensors).
model = create_model(ModelSideMoveVel, dims=[32], act=th.nn.ReLU, load_state=True)

# Target to learn; the evaluation cells below reuse `var`.
var = 'dpos'
# del saved_models[model.id]  # uncomment to discard the stored best state first
losses, val_losses, min_loss, best_state = fit(model, X_train, Y_train[var], X_val, Y_val[var], lossfn=nn.L1Loss(), epochs=10000,
    optimizer=th.optim.Adam(model.parameters(), lr=0.1))

fig, ax = plt.subplots(figsize=(10, 4))
plot_loss(losses, 'train', ax)
plot_loss(val_losses, 'val', ax)

In [None]:
model = create_model(ModelSideMoveVel, dims=[32], act=th.nn.ReLU)
model.eval()  # batch norm uses running statistics for the error plots
fig, ax = plt.subplots(1, 2, figsize=(14, 5))
plot_error_histogram(model, X_val, Y_val[var], res=0.01, ax=ax[0], title='val')
plot_error_histogram(model, X_train, Y_train[var], res=0.01, ax=ax[1], title='train')

In [None]:

def within_dpos_tolerance(inputs, outputs, tol=0.01):
  """Vectorised mask: True where |model(**inputs) - outputs[var]| < tol.

  Uses the global `model` and `var`. Evaluated without autograd.
  """
  with th.no_grad():
    return (model(**inputs) - outputs[var]).abs() < tol

# Batch norm must use running statistics while classifying samples.
model.eval()

# Split train/val into well-predicted (l_*) and poorly-predicted (h_*) samples.
l_train_dataset, h_train_dataset = train_dataset.condition_split_parallel(within_dpos_tolerance)
l_val_dataset, h_val_dataset = val_dataset.condition_split_parallel(within_dpos_tolerance)

X, Y = l_train_dataset[:]
plot_error_histogram(model, X, Y[var], title='train |error| < 0.01', res=0.02)

In [None]:
X_h_train, Y_h_train = h_train_dataset[:]
X_h_val, Y_h_val = h_val_dataset[:]
model = create_model(ModelSideMoveVel, dims=[4], act=th.nn.ReLU)
# Validate on the high-error validation subset so that training and validation
# samples come from the same (poorly predicted) population.
losses, val_losses, min_loss, best_state = fit(model, X_h_train, Y_h_train[var], X_h_val, Y_h_val[var], lossfn=nn.L1Loss(), epochs=10000,
    optimizer=th.optim.Adam(model.parameters(), lr=0.1))

fig, ax = plt.subplots(figsize=(10, 4))
plot_loss(losses, 'train', ax)
plot_loss(val_losses, 'val', ax)

In [None]:
model = create_model(ModelSideMoveVel, dims=[4], act=th.nn.ReLU)
model.eval()  # batch norm uses running statistics for the error plots
fig, ax = plt.subplots(1, 2, figsize=(14, 5))
plot_error_histogram(model, X_val, Y_val[var], res=0.01, ax=ax[0], title='val')
plot_error_histogram(model, X_train, Y_train[var], res=0.01, ax=ax[1], title='train')