# Model 4: Two Stage Model

In [None]:
import torch
import numpy as np

In [None]:
device = ("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

In [None]:
# Import local modules from 'src/utils' as package 'utils'
import sys; sys.path.insert(0, '../')

## Load ParallelTrajectoryDatasets

In [None]:
from pathlib import Path
from utils.file_io import read_parallel_trajectory_datasets

In [None]:
data_path = Path("../../data/boomer_data_simulation/")

feature_columns = [
    'left_boom_base_yaw_joint', 'left_boom_base_pitch_joint', 'left_boom_main_prismatic_joint', 'left_boom_second_roll_joint',
    'left_boom_second_yaw_joint', 'left_boom_top_pitch_joint', 'cable1_property(length,youngsmodule(bend,twist))', 
    'cable2_property(length,youngsmodule(bend,twist))', 'cable3_property(length,youngsmodule(bend,twist))'
]

label_features = [
    ('cable1_lowest_point', np.array([0, 1, 2], dtype=np.int64)),
    ('cable2_lowest_point', np.array([0, 1, 2], dtype=np.int64)),
    ('cable3_lowest_point', np.array([0, 1, 2], dtype=np.int64))
]

normalized_features = [
    ('cable1_property(length,youngsmodule(bend,twist))', np.array([1, 2], dtype=np.int64)),
    ('cable2_property(length,youngsmodule(bend,twist))', np.array([1, 2], dtype=np.int64)),
    ('cable3_property(length,youngsmodule(bend,twist))', np.array([1, 2], dtype=np.int64))
]

In [None]:
train_set, test_set, validation_set, visualization_set = read_parallel_trajectory_datasets(data_path, 0.85, 0.10, 0.045, 0.005, 256, 
                                                                                    feature_columns=feature_columns, 
                                                                                    label_features=label_features,
                                                                                    normalized_features=normalized_features)

In [None]:
features, labels, last_indices = test_set[257] 
print(features.shape, labels.shape, last_indices)
input_shape, output_shape = features.shape[-1], labels.shape[-1]

## Define Transformer Encoder Model

In [None]:
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from umap import UMAP
import math

### Transformer positional encoding

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first tensor.

    A table of shape ``[max_len, 1, d_model]`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature columns hold sines and
    odd feature columns hold cosines of the position scaled by geometrically
    spaced frequencies.

    Args:
        d_model: Size of the last (feature) dimension of the inputs.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the table supports.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        cosine_embed = torch.cos(position * div_term)
        # With an odd d_model there is one more even column than odd columns,
        # so the last cosine frequency has no slot and is dropped.
        pe[:, 0, 1::2] = cosine_embed if d_model % 2 == 0 else cosine_embed[:, :-1]
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            ``x`` with the positional table added (broadcast over the batch
            dimension), followed by dropout.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` of the table.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional table size {self.pe.size(0)}"
            )
        # The original `x =+ pe` assigned `+pe` to x and threw the input away.
        x = x + self.pe[:seq_len]
        return self.dropout(x)

In [None]:
class TransformerEncoderModel(nn.Module):
    """Transformer encoder stack with an optional UMAP projection in front.

    Args:
        num_heads: Number of attention heads per encoder layer.
        model_dim: Feature size the encoder layers operate on.
        feedforward_hidden_dim: Hidden size of each layer's feed-forward part.
        num_encoder_layers: Number of stacked encoder layers.
        transformer_dropout: Dropout used inside the encoder layers.
        pos_encoder_dropout: Dropout used by the positional encoding.
        downprojection: When true, inputs are passed through a UMAP
            projection to ``model_dim`` features before encoding. The
            projection must be fitted before the first forward pass.
        projection_num_neighbors: ``n_neighbors`` of the UMAP projection.
    """

    def __init__(self, num_heads: int, model_dim: int, feedforward_hidden_dim: int,
                 num_encoder_layers: int = 6, transformer_dropout: float = 0.1, pos_encoder_dropout: float = 0.25,
                 downprojection: bool = False, projection_num_neighbors: int = 5) -> None:
        super().__init__()
        self.model_type = 'Transformer'
        self.total_epochs = 0
        self.model_dim = model_dim
        self.downprojection = downprojection
        # Always define the attribute so lookups do not fail when no
        # projection is configured.
        self.projection_function = None
        if self.downprojection:
            self.create_projection(projection_num_neighbors)
        encoder_layers = TransformerEncoderLayer(model_dim, num_heads, feedforward_hidden_dim, transformer_dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_encoder_layers)
        self.pos_encoder = PositionalEncoding(model_dim, pos_encoder_dropout)

    def create_projection(self, projection_num_neighbors: int):
        """Create the (unfitted) UMAP projection down to ``model_dim`` features."""
        self.projection_function = UMAP(n_components=self.model_dim, n_neighbors=projection_num_neighbors)

    def forward(self, source: Tensor, source_msk: Tensor = None) -> Tensor:
        """Project (optionally), add positional encoding, then encode.

        Expects ``source`` of shape (S, N, E): S is the sequence length,
        N the batch size and E the input dimensionality.
        """
        source = self.project(source)
        encoded = self.pos_encoder(source)
        return self.transformer_encoder(encoded, source_msk)

    def project(self, source: Tensor) -> Tensor:
        """Apply the UMAP projection to the last dimension of ``source``.

        UMAP works on 2-D NumPy arrays, so the tensor is flattened to
        ``(-1, E)`` (the layout ``train_downprojection`` fits on), projected,
        and converted back to a tensor with the original leading dimensions,
        dtype and device. No gradient flows through the projection.
        Returns ``source`` unchanged when downprojection is disabled.
        """
        if not self.downprojection:
            return source
        flat = source.detach().reshape(-1, source.shape[-1]).cpu().numpy()
        projected = self.projection_function.transform(flat)
        projected = torch.as_tensor(projected, dtype=source.dtype, device=source.device)
        return projected.reshape(*source.shape[:-1], projected.shape[-1])
    

class ParallelEncoderModel(nn.Module):
    """One shared encoder followed by a separate linear head per trajectory.

    ``forward`` indexes the first dimension of its input once per head, runs
    the shared encoder on that slice, applies the head and a ReLU, and stacks
    the per-head results along a new leading dimension.
    """

    def __init__(self, num_decoders: int, num_heads: int, model_dim: int, feedforward_hidden_dim: int, output_dim: int,
                 num_encoder_layers: int = 6, transformer_dropout: float = 0.1, pos_encoder_dropout: float = 0.25,
                 downprojection: bool = False, projection_num_neighbors: int = 5) -> None:
        super().__init__()
        self.model_type = 'Transformer'
        self.total_epochs = 0
        self.encoder = TransformerEncoderModel(
            num_heads,
            model_dim,
            feedforward_hidden_dim,
            num_encoder_layers,
            transformer_dropout,
            pos_encoder_dropout,
            downprojection=downprojection,
            projection_num_neighbors=projection_num_neighbors,
        )
        linear_heads = [nn.Linear(model_dim, output_dim) for _ in range(num_decoders)]
        self.decoders = nn.ModuleList(linear_heads)
        self.activation = nn.ReLU()

    def forward(self, source: Tensor, source_mask: Tensor = None) -> Tensor:
        """Encode and decode every parallel trajectory; stack along dim 0."""

        def run_head(index: int, head: nn.Module) -> Tensor:
            # Slice `index` of the leading dimension belongs to head `index`.
            encoded = self.encoder(source[index, :, :, :], source_mask)
            return self.activation(head(encoded))

        per_trajectory = [run_head(index, head) for index, head in enumerate(self.decoders)]
        return torch.stack(per_trajectory, dim=0)

    @property
    def downprojection(self):
        """Whether the wrapped encoder uses a downprojection."""
        return self.encoder.downprojection


In [None]:
encoder_model = ParallelEncoderModel(2, 3, input_shape, 64, output_shape).to(device) 

## Model Training Step 1

### Load parameters, functions, and Dataloader

In [None]:
import os

from typing import Any, Tuple

from torch.utils.data import DataLoader
from dotenv import load_dotenv

from utils.file_io import save_model
from utils.file_io import define_dataloader_from_subset
from utils.evaluation import compute_loss_on
from utils.optimizer import rate

In [None]:
model_path = Path("../../models/encoder/").absolute()

In [None]:
dotenv_path = model_path / ".env"
load_dotenv(dotenv_path=dotenv_path)

batch_size = int(os.getenv("BATCH_SIZE"))
num_epochs = int(os.getenv("NUM_EPOCHS"))

In [None]:
def get_optimizer_function_and_learning_rate_scheduler(model: nn.Module, model_size: int, warmup_steps: int, factor: float = 1) -> Tuple[Any, Any]:
    """Build an AdamW optimizer and a ``LambdaLR`` scheduler for ``model``.

    The optimizer's base learning rate is 1, so the effective learning rate
    at each step is exactly the value returned by
    ``rate(step, model_size, factor, warmup_steps)``.

    Returns:
        The ``(optimizer, lr_scheduler)`` pair.
    """

    def lr_multiplier(step: int) -> float:
        return rate(step, model_size, factor, warmup_steps)

    adamw = torch.optim.AdamW(model.parameters(), lr=1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=adamw, lr_lambda=lr_multiplier)
    return adamw, scheduler


def get_loss_function() -> nn.Module:
    """Return a fresh mean-squared-error loss module."""
    mse_loss = nn.MSELoss()
    return mse_loss

In [None]:
train_dataloader, validation_dataloader, test_dataloader = define_dataloader_from_subset(train_set, validation_set, test_set, batch_size=batch_size, shuffle=True)

### Define Train Functions

#### Feature Downprojection

In [None]:
from torch.utils.data import DataLoader

In [None]:
def train_downprojection(projection: UMAP, train_dataloader: DataLoader) -> None:
    """Fit ``projection`` in place on every feature vector of the loader.

    Each batch's features are flattened so that all leading dimensions
    collapse into one and only the last (feature) dimension remains; the
    flattened batches are concatenated and handed to ``projection.fit``.
    """
    flattened_batches = [
        torch.flatten(batch_features, start_dim=0, end_dim=-2)
        for batch_features, _ in train_dataloader
    ]
    all_features = torch.concat(flattened_batches, dim=0)
    projection.fit(all_features)

#### Model train methods

In [None]:
from typing import Callable

from ray import train as ray_train
from ray.train import Checkpoint

In [None]:
def train_epoch_parallel(train_dataloader: DataLoader, model, loss_function, optimizer, lr_scheduler,
                         device: torch.device, report_interval: int = 10):
    """Train ``model`` for one epoch on parallel-trajectory batches.

    Each batch is moved to ``device`` and its axes are reordered from
    ``(d0, d1, d2, d3)`` to ``(d1, d2, d0, d3)`` before the forward pass.
    Assumes the loader yields batch-first tensors
    (batch, trajectory, sequence, feature), so the model sees
    (trajectory, sequence, batch, feature) - TODO confirm against the dataset.

    Args:
        train_dataloader: Iterable of ``(inputs, true_values)`` 4-D tensors.
        model: Module called as ``model(inputs)``.
        loss_function: Callable ``(outputs, true_values) -> scalar tensor``.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Scheduler stepped once per batch.
        device: Device the batches are moved to.
        report_interval: Number of batches between progress reports.

    Returns:
        The mean loss (a Python float) of the last completed report window,
        or ``0.0`` if fewer than ``report_interval`` batches were seen.
    """
    running_loss = 0.0
    last_loss = 0.0

    for i, (inputs, true_values) in enumerate(train_dataloader):
        # permute() actually moves the axes; view() with swapped sizes only
        # reinterprets memory and would mix values from different samples.
        inputs = inputs.to(device).permute(1, 2, 0, 3).contiguous()
        true_values = true_values.to(device).permute(1, 2, 0, 3).contiguous()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, true_values)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # Accumulate a detached Python float so the autograd graph of every
        # batch is not kept alive for the whole report window.
        running_loss += loss.item()

        if i % report_interval == report_interval - 1:
            last_loss = running_loss / report_interval
            print(f"batch {i + 1}, Mean Squared Error: {last_loss}")
            running_loss = 0.0

    # NOTE(review): compute_loss_on(..., reshape=True) should use the same
    # axis order for validation - confirm in utils.evaluation.
    return last_loss

In [None]:
def train(train_epoch_func: Callable, epochs: int, train_dataloader: DataLoader, validation_dataloader: DataLoader, model: nn.Module, 
          loss_function, optimizer, lr_scheduler, checkpoint_path: Path, device: torch.device = 'cpu',
          report_interval: int = 10, tune: bool = False) -> nn.Module:
    """Run the epoch loop with validation, checkpointing and Ray Tune reporting.

    Args:
        train_epoch_func: Called once per epoch as
            ``f(train_dataloader, model, loss_function, optimizer,
            lr_scheduler, device, report_interval)``; returns the train loss.
        epochs: Total number of epochs; training resumes from
            ``model.total_epochs``.
        train_dataloader: Training batches.
        validation_dataloader: Batches passed to ``compute_loss_on``.
        model: Module exposing ``total_epochs`` and ``downprojection``.
        loss_function: Loss passed to the epoch and validation functions.
        optimizer: Optimizer passed to the epoch function.
        lr_scheduler: Scheduler passed to the epoch function.
        checkpoint_path: Directory that receives ``checkpoint.pt``.
        device: Device the model is moved to.
        report_interval: Forwarded to ``train_epoch_func``.
        tune: When true, restore from and report to Ray Tune, and save a
            checkpoint every epoch.

    Returns:
        The trained model (wrapped in ``nn.DataParallel`` when more than one
        CUDA device is visible).
    """
    best_val_loss = float("inf")
    if model.downprojection:
        print("Fitting downprojection!")
        # The wrapper models only expose `downprojection`; the projection
        # object itself lives on their `encoder`.
        projection = getattr(model, "projection_function", None)
        if projection is None:
            projection = model.encoder.projection_function
        train_downprojection(projection, train_dataloader)

    # Keep a handle on the unwrapped module: DataParallel does not forward
    # custom attributes and prefixes state-dict keys with "module.".
    base_model = model
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model.to(device)

    if tune:
        checkpoint = ray_train.get_checkpoint()

        if checkpoint:
            with checkpoint.as_directory() as checkpoint_dir:
                model_state = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"), map_location=device)
                base_model.load_state_dict(model_state)

    for epoch in range(base_model.total_epochs, epochs):
        print(f"Epoch: {epoch + 1}")

        model.train(True)
        avg_loss = train_epoch_func(train_dataloader, model, loss_function, optimizer, lr_scheduler, device, report_interval)
        model.eval()

        with torch.no_grad():
            avg_val_loss = compute_loss_on(validation_dataloader, model, loss_function, reshape=True, device=device)

        print(f"Loss on train: {avg_loss}, loss on validation: {avg_val_loss}")

        base_model.total_epochs += 1

        # Under Ray Tune every epoch is checkpointed so the report below
        # always has a file to attach; otherwise only improvements are kept.
        if avg_val_loss < best_val_loss or tune:
            best_val_loss = avg_val_loss

            torch.save(base_model.state_dict(), checkpoint_path / "checkpoint.pt")

        if tune:
            ray_train.report(metrics={ "loss": float(avg_val_loss) }, checkpoint=Checkpoint.from_directory(checkpoint_path))

    return model

## Train the model with Ray Tune hyperparameter search (Optuna search algorithm)

In [None]:
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from typing import Dict

In [None]:
def parameter_train(parameter: Dict, train_epochs: int, train_data: DataLoader, validation_data: DataLoader, model_input_shape: int,
                    model_output_shape: int, checkpoint_path: Path, device: torch.device) -> None:
    """Ray Tune trainable: train one ``ParallelEncoderModel`` configuration.

    Args:
        parameter: Sampled hyperparameters; reads ``model_dim``,
            ``feedforward_dim``, ``encoder_layer``, ``transformer_dropout``,
            ``pos_encoder_dropout`` and ``warmup_steps``.
        train_epochs: Number of epochs to train.
        train_data: Training loader; its first dataset item determines the
            number of parallel trajectories (one decoder head each).
        validation_data: Validation loader.
        model_input_shape: Currently unused; kept for interface stability.
        model_output_shape: Output feature size of every decoder head.
        checkpoint_path: Parent directory; a sub-directory named after the
            trial id is created inside it.
        device: Device to train on.
    """
    # Star-unpacking tolerates dataset items that carry extra entries after
    # the features (the original `features. _ = ...` was a typo that assigned
    # an attribute on a global and read the global loader).
    features, *_ = train_data.dataset[0]
    parallel_trajectories = features.shape[0]
    run_id = ray_train.get_context().get_trial_id()
    run_checkpoint = checkpoint_path / run_id
    # exist_ok: a restored or retried trial reuses its directory.
    run_checkpoint.mkdir(parents=True, exist_ok=True)

    # model_dim doubles as the number of attention heads.
    model = ParallelEncoderModel(parallel_trajectories, parameter["model_dim"], parameter["model_dim"], parameter["feedforward_dim"], model_output_shape,
                                 parameter["encoder_layer"], parameter["transformer_dropout"], parameter["pos_encoder_dropout"],
                                 ).to(device)

    optimizer, lr_scheduler = get_optimizer_function_and_learning_rate_scheduler(model, parameter["model_dim"], warmup_steps=parameter["warmup_steps"])
    loss_function = get_loss_function()

    _ = train(train_epoch_parallel, train_epochs, train_data, validation_data, model, loss_function, optimizer, lr_scheduler, run_checkpoint, device, report_interval=50, tune=True)

In [None]:
learning_rate_radius = 1e-3
batch_size_radius = 10
num_samples = 100

In [None]:
parameter_space = {
    "batch_size": tune.choice(list(range(batch_size - batch_size_radius, batch_size + batch_size_radius, 4))),
    "model_dim": tune.choice([9]),
    "downprojection": tune.choice([True]),
    "projection_num_neighbors": tune.choice(list(range(2, 10))),
    "warmup_steps": tune.choice(list(range(1000, 2000, 200))),
    "feedforward_dim": tune.choice([32, 64, 128]),
    "encoder_layer": tune.choice([1, 2, 3]),
    "transformer_dropout": tune.uniform(0.1, 0.5),
    "pos_encoder_dropout": tune.uniform(0.1, 0.5)
}

In [None]:
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=num_epochs
)

In [None]:
search_alg = OptunaSearch(
    metric="loss",
    mode="min"
) 

In [None]:
import utils
if ray.is_initialized():
    ray.shutdown()

ray.init(runtime_env={ "py_modules": [utils] })

In [None]:
ray_resources_manager = tune.with_resources(
    trainable=lambda param: parameter_train(param, num_epochs, train_dataloader, validation_dataloader, input_shape, output_shape, model_path, device),
    # See: https://stackoverflow.com/questions/58967793/what-is-the-way-to-make-tune-run-parallel-trials-across-multiple-gpus
    resources={ "cpu": 3, "gpu": 0.25 }
)

tuner = tune.Tuner(
    ray_resources_manager,
    param_space=parameter_space,
    tune_config=tune.TuneConfig(
        scheduler=scheduler,
        search_alg=search_alg,
        num_samples=num_samples
    )
)

In [None]:
results = tuner.fit()

In [None]:
if ray.is_initialized():
    ray.shutdown()

In [None]:
# Save as csv file
results.get_dataframe().to_csv(model_path / "trail_grid.csv")

In [None]:
best_parallel_result = results.get_best_result("loss", "min")
best_parallel_checkpoint = best_parallel_result.get_best_checkpoint("loss", "min")

best_parallel_model_dict = torch.load(f"{best_parallel_checkpoint.path}/checkpoint.pt")

In [None]:
print(f"Best trail by loss value {best_parallel_result.metrics['loss']}", "\n------")
for i in best_parallel_result.config:
    print(f"Best trail: {i} value {best_parallel_result.config[i]}")

In [None]:
best_parallel_model = ParallelEncoderModel(
    2,
    best_parallel_result.config["model_dim"], 
    best_parallel_result.config["model_dim"], 
    best_parallel_result.config["feedforward_dim"], 
    output_shape,
    num_encoder_layers=best_parallel_result.config["encoder_layer"],
    transformer_dropout=best_parallel_result.config["transformer_dropout"],
    pos_encoder_dropout=best_parallel_result.config["pos_encoder_dropout"]
)
best_parallel_model.load_state_dict(best_parallel_model_dict)
encoder = best_parallel_model.encoder

## Predict Length and Cable Properties from Sample Trajectory

In [None]:
def predict_cable_properties_from_complete_data(model: nn.Module, trajectory_features: Tensor, cable_properties_dim: int, 
                                                trajectory_lowpoints: Tensor, n_iterations: int = 100, lr: float = 1e-4, 
                                                loss_function=nn.MSELoss()):
    """Hide the trailing cable-property features and recover them by optimisation.

    Args:
        model: Trained model mapping complete features to low points.
        trajectory_features: 4-D tensor whose last ``cable_properties_dim``
            features are the true cable properties.
        cable_properties_dim: Number of trailing feature columns to recover.
        trajectory_lowpoints: Target tensor for the model output.
        n_iterations: Minimum number of optimisation steps.
        lr: SGD learning rate.
        loss_function: Loss between model output and ``trajectory_lowpoints``.

    Returns:
        ``(predicted_cable_properties, true_cable_properties)``.
    """
    true_cable_properties = trajectory_features[:, :, :, -cable_properties_dim:]
    known_features = trajectory_features[:, :, :, :-cable_properties_dim]
    predicted_cable_properties = predict_cable_properties(model, known_features, trajectory_lowpoints,
                                                          cable_properties_dim, n_iterations, lr, loss_function)
    return predicted_cable_properties, true_cable_properties


def predict_cable_properties(model: nn.Module, trajectory_features: Tensor, trajectory_lowpoints: Tensor, 
                             cable_properties_dim: int, n_iterations: int, lr: float, loss_function,
                             loss_threshold: float = 2.0, max_iterations: int = 10_000):
    """Estimate cable properties by gradient descent on the model input.

    The model is put in eval mode and its parameters are frozen (this is a
    lasting side effect on ``model``). A randomly initialised block of
    ``cable_properties_dim`` features is appended to the known features and
    optimised with SGD so that the model output matches
    ``trajectory_lowpoints``. Only the appended block is updated; the known
    features stay fixed.

    Args:
        model: Trained model called as ``model(complete_input)``.
        trajectory_features: 4-D tensor of known features.
        trajectory_lowpoints: Target for the model output.
        cable_properties_dim: Number of feature columns to estimate.
        n_iterations: Minimum number of optimisation steps.
        lr: SGD learning rate.
        loss_function: Loss between model output and target.
        loss_threshold: Keep iterating past ``n_iterations`` while the loss
            is above this value.
        max_iterations: Hard cap on the number of steps so a loss that never
            drops below ``loss_threshold`` cannot loop forever.

    Returns:
        Detached tensor with the estimated cable properties.
    """
    model.eval()
    freeze_parameters(model)
    cable_properties = torch.rand(
        trajectory_features.shape[0], trajectory_features.shape[1], trajectory_features.shape[2],
        cable_properties_dim, dtype=trajectory_features.dtype, device=trajectory_features.device,
        requires_grad=True,
    )
    optimizer = torch.optim.SGD([cable_properties], lr=lr)
    iteration_cap = max(max_iterations, n_iterations)
    loss_value = float('inf')
    i = 0
    while (loss_value > loss_threshold or i < n_iterations) and i < iteration_cap:
        optimizer.zero_grad()
        # Rebuild the input every step so the graph links the loss to the
        # current value of `cable_properties`.
        complete_input = torch.cat([trajectory_features, cable_properties], dim=-1)
        predicted_lowpoints = model(complete_input)
        # The loss must stay attached to the graph; wrapping it in a new
        # Variable cut it off, so no gradient ever reached the input.
        loss = loss_function(predicted_lowpoints, trajectory_lowpoints)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        i += 1

    return cable_properties.detach()


def freeze_parameters(model: nn.Module) -> None:
    """Disable gradient tracking for every parameter of ``model`` (in place)."""
    for param in model.parameters():
        param.requires_grad = False

In [None]:
# Take the first validation sample and insert a singleton dimension at
# index 2 of both tensors before handing them to the model.
# NOTE(review): an earlier cell unpacks three values from a dataset item
# (features, labels, last_indices) - confirm this two-value unpack matches.
trajectories, lowpoints = validation_set[0]
trajectories, lowpoints = trajectories.unsqueeze(2), lowpoints.unsqueeze(2)
print(trajectories.shape, lowpoints.shape)
# The literal 9 is cable_properties_dim: the number of trailing feature
# columns treated as unknown. NOTE(review): verify it matches the width of
# the three cable-property features.
predicted_cable_properties, true_cable_properties = predict_cable_properties_from_complete_data(
    best_parallel_model, trajectories, 9, lowpoints
)

# Distance between recovered and true cable properties (lower is better).
print(torch.norm(true_cable_properties - predicted_cable_properties))

## Model Training Step 2

### Decoder Definition

In [None]:
from torch.nn import TransformerDecoder, TransformerDecoderLayer

In [None]:
class TransformerDecoderModel(nn.Module):
    """Transformer decoder stack that reuses an existing positional encoding.

    Args:
        model_dim: Feature size of target and memory.
        num_heads: Number of attention heads per decoder layer.
        feedforward_dim: Hidden size of each layer's feed-forward part.
        num_decoder_layers: Number of stacked decoder layers.
        pos_encoder: Positional encoding module applied to the target.
        transformer_dropout: Dropout used inside the decoder layers.
    """

    def __init__(self, model_dim: int, num_heads: int, feedforward_dim: int, num_decoder_layers: int, 
                 pos_encoder: PositionalEncoding, transformer_dropout: float = 0.25) -> None:
        super().__init__()
        self.model_type = 'Transformer'
        self.total_epochs = 0
        self.model_dim = model_dim
        decoder_layer = TransformerDecoderLayer(model_dim, num_heads, dim_feedforward=feedforward_dim, dropout=transformer_dropout)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
        self.pos_encoder = pos_encoder

    def forward(self, memory: Tensor, target: Tensor, target_mask: Tensor = None) -> Tensor:
        """Decode ``target`` while attending to the encoder ``memory``.

        Expects inputs of shape (S, N, E): S is the sequence length, N the
        batch size and E the feature dimensionality. When ``target_mask`` is
        omitted a causal (square subsequent) mask is created.
        """
        if target_mask is None:
            target_mask = nn.Transformer.generate_square_subsequent_mask(target.shape[0])
            # The default mask is created on the CPU; attention needs it on
            # the same device as the target.
            target_mask = target_mask.to(target.device)
        target = self.pos_encoder(target)
        # nn.TransformerDecoder takes (tgt, memory); the arguments used to be
        # passed the other way round.
        return self.decoder(target, memory, tgt_mask=target_mask)
    

class TransformerModel(nn.Module):
    """Encoder/decoder transformer built from two pre-constructed halves.

    Raises:
        ValueError: If encoder and decoder disagree on ``model_dim``.
    """

    def __init__(self, encoder: TransformerEncoderModel, decoder: TransformerDecoderModel) -> None:
        super().__init__()
        if encoder.model_dim != decoder.model_dim:
            raise ValueError("Both encoder and decoder must have the same model dimension!")
        self.model_type = 'Transformer'
        self.total_epochs = 0
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source: Tensor, target: Tensor = None, source_mask: Tensor = None, 
                target_mask: Tensor = None, mode: str = "train") -> Tensor:
        """Dispatch to teacher-forced training or autoregressive generation.

        Inputs are expected as (S, N, E): S is the sequence length, N the
        batch size and E the feature dimensionality. ``target_mask`` hides
        every position to the right of the diagonal so that later target
        steps cannot leak into earlier ones during training.

        Raises:
            ValueError: If ``mode`` is neither ``"train"`` nor ``"generate"``,
                or if ``mode == "train"`` and no target is given.
        """
        if mode == "train":
            prediction = self.forward_train(source, target, source_mask, target_mask)
        elif mode == "generate":
            prediction = self.generate(source)
        else:
            raise ValueError("Invalid mode selected!")

        return prediction

    def forward_train(self, source: Tensor, target: Tensor, source_mask: Tensor = None, 
                      target_mask: Tensor = None) -> Tensor:
        """Encode ``source`` and decode the provided ``target`` in one pass."""
        if target is None:
            raise ValueError("In train mode a target sequence has to be provided!")

        memory = self.encoder(source, source_mask)
        return self.decoder(memory, target, target_mask)

    def generate(self, source: Tensor):
        """Autoregressively decode a sequence as long as ``source``.

        Decoding starts from an all-zero first target step. At step ``i`` the
        decoder sees the full encoder memory and the target prefix
        ``target[:i + 1]``; its last output becomes ``predicted[i]`` and is
        fed back as ``target[i + 1]``. Runs without gradient tracking. The
        result has the same shape as ``source``.
        """
        with torch.no_grad():
            memory = self.encoder(source)
            target = torch.zeros_like(source)
            predicted = torch.zeros_like(source)
            sequence_length = source.shape[0]

            for i in range(sequence_length):
                decoded_prefix = self.decoder(memory, target[:i + 1])
                predicted[i] = decoded_prefix[-1]
                # The last prediction has no following target slot.
                if i + 1 < sequence_length:
                    target[i + 1] = predicted[i]

        return predicted

    @property
    def downprojection(self) -> bool:
        """Whether the wrapped encoder uses a downprojection."""
        return self.encoder.downprojection


### Model Train Functions

#### Encoder fixed training

#### Load datasets

In [None]:
from utils.file_io import read_trajectory_datasets

In [None]:
train_set, test_set, validation_set, visualization_set = read_trajectory_datasets(data_path, 0.85, 0.10, 0.045, 0.005, 256, 
                                                                                  feature_columns=feature_columns, label_features=label_features)
train_dataloader, validation_dataloader, test_dataloader = define_dataloader_from_subset(train_set, validation_set, test_set, batch_size=batch_size, shuffle=True)

#### Define train methods

In [None]:
def train_epoch(train_dataloader: DataLoader, model, loss_function, optimizer, lr_scheduler,
                device: torch.device, report_interval: int = 1000):
    """Train an encoder/decoder ``model`` for one epoch with teacher forcing.

    Each batch is moved to ``device`` and its first two axes are swapped.
    Assumes the loader yields batch-first tensors (batch, sequence, feature),
    so the model sees (sequence, batch, feature) - TODO confirm against the
    dataset. The decoder input is the label sequence shifted one step to the
    right with a zero first step, so the prediction for step ``t`` never
    sees the label for step ``t``.

    Args:
        train_dataloader: Iterable of ``(inputs, true_values)`` 3-D tensors.
        model: Module called as ``model(inputs, decoder_input)``.
        loss_function: Callable ``(outputs, true_values) -> scalar tensor``.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Scheduler stepped once per batch.
        device: Device the batches are moved to.
        report_interval: Number of batches between progress reports.

    Returns:
        The mean loss (a Python float) of the last completed report window,
        or ``0.0`` if fewer than ``report_interval`` batches were seen.
    """
    running_loss = 0.0
    last_loss = 0.0

    for i, (inputs, true_values) in enumerate(train_dataloader):
        # transpose() actually swaps the axes; view() with swapped sizes only
        # reinterprets memory and would mix values from different samples.
        inputs = inputs.to(device).transpose(0, 1).contiguous()
        true_values = true_values.to(device).transpose(0, 1).contiguous()

        # Shift right: zero start step followed by all labels but the last.
        start_step = torch.zeros_like(true_values[:1])
        decoder_input = torch.cat([start_step, true_values[:-1]], dim=0)

        optimizer.zero_grad()
        outputs = model(inputs, decoder_input)
        loss = loss_function(outputs, true_values)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # Accumulate a detached Python float so the autograd graph of every
        # batch is not kept alive for the whole report window.
        running_loss += loss.item()

        if i % report_interval == report_interval - 1:
            last_loss = running_loss / report_interval
            print(f"batch {i + 1}, Mean Squared Error: {last_loss}")
            running_loss = 0.0

    return last_loss

In [None]:
def parameter_train_decoder(encoder: TransformerEncoderModel, parameter: Dict, train_epochs: int, train_data: DataLoader, 
                            validation_data: DataLoader, checkpoint_path: Path, 
                            device: torch.device) -> None:
    """Ray Tune trainable: train a new decoder on top of a given encoder.

    Only the decoder's parameters are handed to the optimizer, so the
    encoder weights are not updated.

    Args:
        encoder: Pre-trained encoder; also supplies ``model_dim`` and the
            positional encoding shared with the decoder.
        parameter: Sampled hyperparameters; reads ``feedforward_dim``,
            ``decoder_layer``, ``transformer_dropout`` and ``warmup_steps``.
        train_epochs: Number of epochs to train.
        train_data: Training loader.
        validation_data: Validation loader.
        checkpoint_path: Parent directory; a sub-directory named after the
            trial id is created inside it.
        device: Device to train on.
    """
    # The original began with `features. _ = train_dataloader.dataset[0]`, a
    # typo that assigned an attribute on a global and whose result was never
    # used; it has been dropped.
    run_id = ray_train.get_context().get_trial_id()
    run_checkpoint = checkpoint_path / run_id
    # exist_ok: a restored or retried trial reuses its directory.
    run_checkpoint.mkdir(parents=True, exist_ok=True)
    # model_dim doubles as the number of attention heads.
    decoder = TransformerDecoderModel(encoder.model_dim, encoder.model_dim, parameter["feedforward_dim"], parameter["decoder_layer"], encoder.pos_encoder, parameter["transformer_dropout"])
    model = TransformerModel(encoder, decoder)

    optimizer, lr_scheduler = get_optimizer_function_and_learning_rate_scheduler(decoder, encoder.model_dim, warmup_steps=parameter["warmup_steps"])
    loss_function = get_loss_function()

    _ = train(train_epoch, train_epochs, train_data, validation_data, model, loss_function, optimizer, lr_scheduler, run_checkpoint, device, report_interval=50, tune=True)

In [None]:
parameter_space = {
    "batch_size": tune.choice(list(range(batch_size - batch_size_radius, batch_size + batch_size_radius, 4))),
    "warmup_steps": tune.choice(list(range(1000, 2000, 200))),
    "feedforward_dim": tune.choice([32, 64, 128]),
    "decoder_layer": tune.choice([1, 2, 3]),
    "transformer_dropout": tune.uniform(0.1, 0.5),
    "pos_encoder_dropout": tune.uniform(0.1, 0.5)
}

In [None]:
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=num_epochs
)

In [None]:
search_alg = OptunaSearch(
    metric="loss",
    mode="min"
) 

In [None]:
if ray.is_initialized():
    ray.shutdown()

ray.init(runtime_env={ "py_modules": [utils] })

In [None]:
ray_resources_manager = tune.with_resources(
    trainable=lambda params: parameter_train_decoder(encoder, params, num_epochs, train_dataloader, validation_dataloader, model_path, device),
    # See: https://stackoverflow.com/questions/58967793/what-is-the-way-to-make-tune-run-parallel-trials-across-multiple-gpus
    resources={ "cpu": 3, "gpu": 0.25 }
)

tuner = tune.Tuner(
    ray_resources_manager,
    param_space=parameter_space,
    tune_config=tune.TuneConfig(
        scheduler=scheduler,
        search_alg=search_alg,
         num_samples=num_samples
    )
)

In [None]:
results = tuner.fit()

In [None]:
if ray.is_initialized():
    ray.shutdown()

In [None]:
# Save as csv file
results.get_dataframe().to_csv(model_path / "trail_grid.csv")

In [None]:
best_decoder_result = results.get_best_result("loss", "min")
best_decoder_checkpoint = best_decoder_result.get_best_checkpoint("loss", "min")

best_decoder_model_dict = torch.load(f"{best_decoder_checkpoint.path}/checkpoint.pt")

In [None]:
print(f"Best trail by loss value {best_decoder_result.metrics['loss']}", "\n------")
for i in best_decoder_result.config:
    print(f"Best trail: {i} value {best_decoder_result.config[i]}")

In [None]:
# Rebuild the stage-2 model: an encoder with the best stage-1 hyperparameters
# and a decoder with the best decoder hyperparameters, then load the weights
# of the best decoder trial.
# NOTE(review): the decoder trials used `best_parallel_model.encoder`, which
# was built without the downprojection arguments, while this encoder passes
# them - confirm that enabling the (unfitted) projection here is intended.
encoder = TransformerEncoderModel(best_parallel_result.config["model_dim"], best_parallel_result.config["model_dim"], 
                                  best_parallel_result.config["feedforward_dim"], best_parallel_result.config["encoder_layer"], 
                                  best_parallel_result.config["transformer_dropout"], best_parallel_result.config["pos_encoder_dropout"], 
                                  best_parallel_result.config["downprojection"], best_parallel_result.config["projection_num_neighbors"])
# The decoder shares the encoder's positional encoding module.
decoder = TransformerDecoderModel(encoder.model_dim, encoder.model_dim, best_decoder_result.config["feedforward_dim"],
                                  best_decoder_result.config["decoder_layer"], encoder.pos_encoder, best_decoder_result.config["transformer_dropout"])
best_decoder_model = TransformerModel(
    encoder,
    decoder
)
best_decoder_model.load_state_dict(best_decoder_model_dict)

### Encoder/Decoder Joint Training

In [None]:
def parameter_train_encoder_decoder(parameter: Dict, train_epochs: int, train_data: DataLoader, 
                            validation_data: DataLoader, checkpoint_path: Path, 
                            device: torch.device) -> None:
    """Ray Tune trainable: train encoder and decoder jointly from scratch.

    Args:
        parameter: Sampled hyperparameters; reads ``model_dim``,
            ``enc_feedforward_dim``, ``encoder_layer``,
            ``enc_transformer_dropout``, ``pos_encoder_dropout``,
            ``downprojection``, ``projection_num_neighbors``,
            ``dec_feedforward_dim``, ``decoder_layer``,
            ``dec_transformer_dropout`` and ``warmup_steps``.
        train_epochs: Number of epochs to train.
        train_data: Training loader.
        validation_data: Validation loader.
        checkpoint_path: Parent directory; a sub-directory named after the
            trial id is created inside it.
        device: Device to train on.
    """
    # The original began with `features. _ = train_dataloader.dataset[0]`, a
    # typo that assigned an attribute on a global and whose result was never
    # used; it has been dropped.
    run_id = ray_train.get_context().get_trial_id()
    run_checkpoint = checkpoint_path / run_id
    # exist_ok: a restored or retried trial reuses its directory.
    run_checkpoint.mkdir(parents=True, exist_ok=True)
    # model_dim doubles as the number of attention heads in both halves.
    encoder = TransformerEncoderModel(parameter["model_dim"], parameter["model_dim"], parameter["enc_feedforward_dim"], 
                                      parameter["encoder_layer"], parameter["enc_transformer_dropout"], parameter["pos_encoder_dropout"], 
                                      parameter["downprojection"], parameter["projection_num_neighbors"])
    decoder = TransformerDecoderModel(encoder.model_dim, encoder.model_dim, parameter["dec_feedforward_dim"], parameter["decoder_layer"], 
                                      encoder.pos_encoder, parameter["dec_transformer_dropout"])
    model = TransformerModel(encoder, decoder)

    optimizer, lr_scheduler = get_optimizer_function_and_learning_rate_scheduler(model, encoder.model_dim, warmup_steps=parameter["warmup_steps"])
    loss_function = get_loss_function()

    _ = train(train_epoch, train_epochs, train_data, validation_data, model, loss_function, optimizer, lr_scheduler, run_checkpoint, device, report_interval=50, tune=True)

In [None]:
parameter_space = {
    "model_dim": tune.choice([15, 9]),
    "batch_size": tune.choice(list(range(batch_size - batch_size_radius, batch_size + batch_size_radius, 4))),
    "warmup_steps": tune.choice(list(range(1000, 2000, 200))),
    "enc_feedforward_dim": tune.choice([32, 64, 128]),
    "dec_feedforward_dim": tune.choice([32, 64, 128]),
    "encoder_layer": tune.choice([1, 2, 3]),
    "decoder_layer": tune.choice([1, 2, 3]),
    "enc_transformer_dropout": tune.uniform(0.1, 0.5),
    "dec_transformer_dropout": tune.uniform(0.1, 0.5),
    "pos_encoder_dropout": tune.uniform(0.1, 0.5),
    "downprojection": tune.choice([True]),
    "projection_num_neighbors": tune.choice(list(range(2, 10)))
}

In [None]:
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=num_epochs
)

In [None]:
search_alg = OptunaSearch(
    metric="loss",
    mode="min"
) 

In [None]:
if ray.is_initialized():
    ray.shutdown()

ray.init(runtime_env={ "py_modules": [utils] })

In [None]:
ray_resources_manager = tune.with_resources(
    trainable=lambda param: parameter_train_encoder_decoder(param, num_epochs, train_dataloader, validation_dataloader, model_path, device),
    # See: https://stackoverflow.com/questions/58967793/what-is-the-way-to-make-tune-run-parallel-trials-across-multiple-gpus
    resources={ "cpu": 3, "gpu": 0.25 }
)

tuner = tune.Tuner(
    ray_resources_manager,
    param_space=parameter_space,
    tune_config=tune.TuneConfig(
        scheduler=scheduler,
        search_alg=search_alg,
         num_samples=num_samples
    )
)

In [None]:
results = tuner.fit()

In [None]:
if ray.is_initialized():
    ray.shutdown()

In [None]:
# Save as csv file
results.get_dataframe().to_csv(model_path / "trail_grid.csv")

In [None]:
best_encoder_decoder_result = results.get_best_result("loss", "min")
best_encoder_decoder_checkpoint = best_encoder_decoder_result.get_best_checkpoint("loss", "min")

best_encoder_decoder_model_dict = torch.load(f"{best_encoder_decoder_checkpoint.path}/checkpoint.pt")

In [None]:
print(f"Best trail by loss value {best_encoder_decoder_result.metrics['loss']}", "\n------")
for i in best_encoder_decoder_result.config:
    print(f"Best trail: {i} value {best_encoder_decoder_result.config[i]}")

In [None]:
# Rebuild the jointly trained model from the best joint-training trial.
# Every hyperparameter comes from `best_encoder_decoder_result`: the
# `enc_*` / `dec_*` keys only exist in the joint search space, so reading
# them from the stage-1 or stage-2 results raises KeyError.
best_joint_config = best_encoder_decoder_result.config
encoder = TransformerEncoderModel(best_joint_config["model_dim"], best_joint_config["model_dim"],
                                  best_joint_config["enc_feedforward_dim"], best_joint_config["encoder_layer"],
                                  best_joint_config["enc_transformer_dropout"], best_joint_config["pos_encoder_dropout"],
                                  best_joint_config["downprojection"], best_joint_config["projection_num_neighbors"])
# Same argument order as in parameter_train_encoder_decoder: the decoder
# shares the encoder's positional encoding, followed by its own dropout.
decoder = TransformerDecoderModel(encoder.model_dim, encoder.model_dim, best_joint_config["dec_feedforward_dim"],
                                  best_joint_config["decoder_layer"], encoder.pos_encoder,
                                  best_joint_config["dec_transformer_dropout"])
best_encoder_decoder_model = TransformerModel(
    encoder,
    decoder
)
# Load the weights of the joint trial, not those of the stage-2 decoder run.
best_encoder_decoder_model.load_state_dict(best_encoder_decoder_model_dict)