In [None]:
"""
load data
"""

import os
import copy

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

import graph
import dataset
import csvio
import interpolation
import data_processing as dp
import prediction as pred
from model import Transformer

# Read every training motion together with the CSV header row.
motions, headers = csvio.load_data('input_data/train/train.csv')

# Group the motions by their position in the file. The index ranges are
# fixed, so they must match the ordering of the rows in train.csv.
bow_motions = motions[:11]
wav_motions = motions[11:33]
frt_motions = motions[45:56]
# The "run" group is stored in two ranges, one on each side of the "frt" range.
run_motions = motions[33:45] + motions[56:67]
wlk_motions = motions[67:91]
bck_motions = motions[91:103]
rgt_motions = motions[103:115]
lft_motions = motions[115:127]

# Make sure the output locations for checkpoints, plots and CSVs exist.
for out_dir in ("model", "graph", "output_data"):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
"""
setting
"""
# `interval` is passed to Transformer(...) and to the dataset / prediction
# helpers; `input_window` is derived from it and from `sections`.
interval = 45
sections = 1
input_window = interval * sections + 1
# `name` is the prefix of the checkpoint files ("model/{name}_{n}.pth") and is
# also handed to the prediction helpers.
name = "45_run"
# The motion group to train and evaluate on.
motion_list = run_motions
# Five evenly spaced values in [0.9, 1.1], forwarded to dataset.make_dataset.
rates = np.linspace(0.9, 1.1, 5)
batch_size = 8
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:

"""
train
"""


def train(source, target):
    """
    Run one training epoch and return the mean loss per optimizer step.

    Reads the notebook-level globals ``model``, ``optimizer``, ``criterion``,
    ``batch_size`` and ``input_window``.

    source: (data_size, window, feature)
    target: (data_size, window, feature)

    Returns the sum of the batch losses divided by the number of steps that
    were actually taken, or 0.0 when ``source`` is too small for any step.
    """
    model.train()
    total_loss = 0.
    # One optimizer step per batch_size-sized stride over the data.
    num_batches = len(range(0, len(source) - batch_size, batch_size))
    for _ in range(num_batches):
        # No index is passed, so the helper chooses the batch itself
        # (presumably by sampling; confirm in dataset.get_both_batch).
        data, targets = dataset.get_both_batch(source, target, batch_size, input_window)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Divide by the number of steps taken. The previous denominator,
    # len(range(0, len(source) - 1, batch_size)), could exceed the step count
    # (understating the loss) and was zero for len(source) <= 1.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches


def evaluate(eval_model, source, target):
    """
    Return the loss of ``eval_model`` on one batch taken from (source, target).

    The model is put into eval mode and gradients are disabled. The batch
    holds at most 64 samples. Reads the notebook-level globals ``criterion``
    and ``input_window``.
    """
    eval_model.eval()
    with torch.no_grad():
        n_eval = min(64, len(source))
        inputs, expected = dataset.get_both_batch(source, target, n_eval, input_window)
        return criterion(eval_model(inputs), expected).item()


# Split the selected motion group into training and validation motions.
train_motions, valid_motions = dataset.get_train_valid(motion_list)

# Train three models from scratch; each one writes its own checkpoint file.
for model_num in range(3):
    model_path = f"model/{name}_{model_num}.pth"
    epochs = 20
    best_valid_loss = float("inf")
    # `model`, `criterion` and `optimizer` are notebook-level names: train()
    # reads all three and evaluate() reads `criterion`, so they must be
    # (re)bound here before the epoch loop starts.
    model = Transformer(interval).to(device).double()
    lr = 0.005
    criterion = nn.MSELoss()
    optimizer = torch.optim.RAdam(model.parameters(), lr=lr, weight_decay=0.00001)
    # step_size=1: the learning rate is multiplied by 0.995 on every
    # scheduler.step(), i.e. once per epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.995)

    print(f"====== {model_path} ======")
    # The datasets are rebuilt for every model.
    # NOTE(review): if make_dataset is deterministic this could be hoisted out
    # of the loop; confirm before moving it.
    train_source, train_target = dataset.make_dataset(train_motions, interval, input_window, rates)
    valid_source, valid_target = dataset.make_dataset(valid_motions, interval, input_window, rates)
    train_loss_list = []
    valid_loss_list = []
    for epoch in range(1, epochs + 1):
        train_loss = train(train_source, train_target)
        valid_loss = evaluate(model, valid_source, valid_target)
        train_loss_list.append(train_loss)
        valid_loss_list.append(valid_loss)
        # Keep only the weights of the epoch with the lowest validation loss.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), model_path)
        print(f'epoch {epoch:3d} | train {train_loss:5.5f} | valid {valid_loss:5.5f}')
        scheduler.step()
        # Called once per epoch with the loss history collected so far.
        graph.plot_loss(train_loss_list, valid_loss_list)
    print("")


In [None]:
"""
test
"""

def plot_data(lerp_data, pred_data, target_data, axis=1, num_markers=21):
    """
    Plot the interpolated, predicted and target curves, one figure per marker.

    Each input is indexed as ``data[:, marker, axis]``, so it must have at
    least three dimensions and at least ``num_markers`` entries along the
    second one.

    lerp_data: linearly interpolated data (drawn dashed, green)
    pred_data: model prediction (drawn red)
    target_data: reference data (drawn blue)
    axis: index into the last dimension selecting the component to plot
    num_markers: how many markers to plot, starting from marker 0
        (defaults to 21, the previously hard-coded count)
    """
    for marker in range(num_markers):
        plt.figure(figsize=(12, 8))
        plt.plot(lerp_data[:, marker, axis], label="lerp", c='g', linestyle='dashed')
        plt.plot(pred_data[:, marker, axis], label="pred", c='r')
        plt.plot(target_data[:, marker, axis], label="target", c='b')
        # Title each figure so the per-marker plots can be told apart.
        plt.title(f"marker {marker} / axis {axis}")
        plt.legend()
        plt.show()
        # Release the figure so a long run of plots does not accumulate memory.
        plt.close()

# Accumulators for the motions that are written to the output CSVs.
all_lerp_motions = []
all_pred_motions = []

# Same formula as in the settings cell; recomputed from interval / sections.
input_window = interval * sections + 1

# A freshly constructed model: no state_dict is loaded in this cell.
# NOTE(review): presumably the pred.* helpers load the saved checkpoints using
# `name`; confirm in prediction.py.
model = Transformer(interval).to(device).double()

print(f"====== {name} ======")
_, valid_motions = dataset.get_train_valid(motion_list)

weights = pred.search_best_weights(model, valid_motions, name, interval, input_window)
# Deep copies, so overwriting `.data` below leaves valid_motions untouched.
lerp_motions = copy.deepcopy(valid_motions)
pred_motions = copy.deepcopy(valid_motions)
total_loss = 0.0
for index, valid_motion in enumerate(valid_motions):
    # Motions shorter than one input window are skipped.
    # NOTE(review): skipped entries keep their original (copied) data in
    # lerp_motions / pred_motions and are still written to the CSVs below.
    if valid_motion.data.shape[0] < input_window:
        continue
    source_data, target_data = dataset.make_predict_data(valid_motion, interval)
    source_data = dp.lost(source_data, interval)

    print("motion:", valid_motion.name, valid_motion.data.shape)

    # Baseline (linear interpolation) and model prediction from the same source.
    lerp_data = interpolation.linear_interpolate(source_data, interval)
    pred_data = pred.predict_with_weights(model, source_data, weights, name, interval, input_window)

    lerp_loss = dp.calc_loss(lerp_data, target_data)
    pred_loss = dp.calc_loss(pred_data, target_data)
    print(f'\tlerp loss: {lerp_loss:5.5f}')
    print(f'\tpred loss: {pred_loss:5.5f}')
    # Sum of the prediction losses over the processed motions (not averaged).
    total_loss += pred_loss

    lerp_motions[index].data = lerp_data
    pred_motions[index].data = pred_data

    # Only the first validation motion is plotted (never, if it was skipped).
    if index == 0:
        plot_data(lerp_data, pred_data, target_data)

print("total_loss:", total_loss)
all_lerp_motions = all_lerp_motions + lerp_motions
all_pred_motions = all_pred_motions + pred_motions

csvio.write_csv("output_data/test_lerp.csv", headers, all_lerp_motions)
csvio.write_csv("output_data/test_pred.csv", headers, all_pred_motions)