In [50]:
import torch
from botorch.models import MultiTaskGP
from botorch.fit import fit_gpytorch_model
from botorch.utils import standardize
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition.monte_carlo import qUpperConfidenceBound
from botorch.acquisition.objective import GenericMCObjective
import numpy as np
import time
import datetime
from botorch.optim import optimize_acqf
from torch.utils.tensorboard import SummaryWriter
import os

In [2]:
tkwargs = {
    "dtype": torch.double,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
}

### 注意
- metricsの計算およびobjectiveも最小化基準で設計すること

In [66]:
import abc

class IMetric(abc.ABC):
    def __init__(self):
        self.optimal_metric = None

    def parameter_parser(self, x):
        # x: Tensor[D]
        # ここでは簡単のためにパラメーターベクトルからパラメーターを取り出す
        gamma = x[0]
        beta = x[1]
        shadow = x[2]
        return gamma, beta, shadow
    def get_image_from_parameter(self, x):
        # x: Tensor[D]
        # ここは実際にはopenCV or simulatorによってパラメーターから画像を生成する
        gamma, beta, shadow = self.parameter_parser(x)
        
        image = gamma + beta + shadow

        return image
    @abc.abstractmethod
    def calc(self, x):
        # 必ずself.__update_optimal_metricを呼び出すこと
        raise NotImplementedError()
    
    @abc.abstractmethod
    def update_optimal_metric(self, y):
        # yを条件判定した上で最適値を更新する
        raise NotImplementedError()

class Metric0(IMetric):
    
    def calc(self, x):
        image = self.get_image_from_parameter(x)
        y = 3 * image**3 + 2 * image**2 + image + 1
        self.update_optimal_metric(y)
        return y

    def update_optimal_metric(self, y):
        # yを条件判定した上で最適値を更新する
        if self.optimal_metric is None:
            self.optimal_metric = y
        elif self.optimal_metric > y:
            self.optimal_metric = y

class Metric1(IMetric):
    
    def calc(self, x):
        image = self.get_image_from_parameter(x)
        y = - 2 * image**2 + image**2 + image + 1
        self.update_optimal_metric(y)
        return y
    
    def update_optimal_metric(self, y):
        # yを条件判定した上で最適値を更新する
        if self.optimal_metric is None:
            self.optimal_metric = y
        elif self.optimal_metric > y:
            self.optimal_metric = y

    
    

In [67]:
class Metrics():
    def __init__(self):
        self.metrics = {
            "metric0": Metric0(),
            "metric1": Metric1(),
        }
        self.metric_idxs = {
            0: self.metrics["metric0"],
            1: self.metrics["metric1"],
        }
        self.n_metrics = len(self.metrics)
        self.objective = GenericMCObjective(self.aggregator)

    def aggregator(self, samples, X=None):
        # multi-taskなので回帰モデルは複数のyを出力する
        # 獲得関数ではそれを集約し,スカラにする必要がある
        # 今回は全メトリクスの和を考える
        # samples: Tensor[sample_shape?, batch, 取り出す次のサンプル数, タスク数]
        
        return samples.sum(dim=-1)
    
    def get_optimal_metrics(self):
        optimal_metrics = { metric_name: metric.optimal_metric for metric_name, metric in self.metric_idxs.items()}
    
    def __call__(self, X):
        # X: Tensor[N, D+1]
        # return: Tensor[N, 1]
        return torch.stack([self.metric_idxs[x[-1].item()].calc(x[:-1]) for x in X]).unsqueeze(dim=1)
    
metrics = Metrics()

In [68]:
class Parameter():
    def __init__(self, bound, tkwargs):
        # bounds: 2d-list
        self.tkwargs = tkwargs
        self.bound = bound

    def random_sample(self, num):
        return (self.bound[1][0] - self.bound[0][0]) * torch.rand(num, 1, **self.tkwargs) + self.bound[0][0]
    
    def get_bound(self):
        return self.bound



class Parameters:
    def __init__(self, tkwargs):
        self.tkwargs = tkwargs
        self.param_idx = {
            "param0": 0,
            "param1": 1,
            "param2": 2,
        }
        self.params = {
            "param0": Parameter([[-20], [20]], tkwargs),
            "param1": Parameter([[-5], [5]], tkwargs),
            "param2": Parameter([[-5], [40]], tkwargs),
        }
        self.n_params = len(self.params)
        self.bounds = torch.tensor(np.concatenate([param.bound for param in self.params.values()], axis=1), **self.tkwargs)
        self.n_params = len(self.params)

    def random_sample(self, q):
        return torch.stack([param.random_sample(q) for param in self.params.values()], dim=1).squeeze(dim=2)
        
parameters = Parameters(tkwargs)

In [69]:
# 学習ハイパラ
n_iter = 100
n_sample_per_iter = 2

In [71]:
# Tensorboard
t_delta = datetime.timedelta(hours=9)
JST = datetime.timezone(t_delta, 'JST')
now = datetime.datetime.now(JST)
log_dir = f"./runs/{now.strftime('%Y%m%d%H%M%S')}"
gp_writer = SummaryWriter(log_dir=os.path.join(log_dir, "gp"))
rand_writer = SummaryWriter(log_dir=os.path.join(log_dir, "rand"))

# 範囲の保持やランダムサンプリングを担当するオブジェクト
parameters = Parameters(tkwargs)
# メトリクス値の計算や最大値の保持を担当するオブジェクト
gp_metrics = Metrics()
rand_metrics = Metrics()

# 初期サンプルの作成
init_params = parameters.random_sample(2)
# メトリクス（タスクの数）だけ縦に積み、タスク番号を最終列に追加
n_metrics = metrics.n_metrics
train_params = torch.cat([torch.cat((init_params, torch.tensor([[i] for _ in range(init_params.shape[0])], **tkwargs)), dim=1) for i in range(n_metrics)])
print(f"initial params: \n{train_params}")

# 初期学習データ
gp_train_params, gp_train_metrics = train_params, gp_metrics(train_params)
rand_train_params, rand_train_metrics = train_params, rand_metrics(train_params)

# tensorboardへ記録
for metric_name in gp_metrics.metrics.keys():
    gp_writer.add_scalar(f"{metric_name}", gp_metrics.metrics[metric_name].optimal_metric, 0)
    rand_writer.add_scalar(f"{metric_name}", rand_metrics.metrics[metric_name].optimal_metric, 0)

for iter in range(n_iter):

    torch.cuda.empty_cache()
    gp_t0 = time.time()
    # モデルの更新
    gp = MultiTaskGP(gp_train_params, gp_train_metrics, task_feature=-1)
    gp_mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    fit_gpytorch_model(gp_mll, options={"maxiter": 3000, "lr": 0.01, "disp": False})
    # 獲得関数の計算
    qUCB = qUpperConfidenceBound(gp, beta=0.1, objective=metrics.objective)
    # 獲得関数の最適化
    gp_candidates, _ = optimize_acqf(
        acq_function=qUCB,
        bounds=parameters.bounds,
        q=n_sample_per_iter,
        num_restarts=10,
        raw_samples=128,  # used for intialization heuristic
        options={"batch_limit": 1, "maxiter": 200, "init_batch_limit": 1},
    )
    gp_t1 = time.time()
    # 得られた探索点にタスク番号を追加
    gp_candidates = torch.cat([torch.cat((gp_candidates, torch.tensor([[i] for _ in range(gp_candidates.shape[0])], **tkwargs)), dim=1) for i in range(n_metrics)])
    gp_candidates_metrics = gp_metrics(gp_candidates)
    # 学習データの更新
    gp_train_params = torch.cat((gp_train_params, gp_candidates), dim=0)
    gp_train_metrics = torch.cat((gp_train_metrics, gp_candidates_metrics), dim=0)

    rand_t0 = time.time()
    rand_candidates = parameters.random_sample(q=n_sample_per_iter)
    # 得られた探索点にタスク番号を追加
    rand_candidates = torch.cat([torch.cat((rand_candidates, torch.tensor([[i] for _ in range(rand_candidates.shape[0])], **tkwargs)), dim=1) for i in range(n_metrics)])
    rand_candidates_metrics = rand_metrics(rand_candidates)
    # 学習データの更新(学習しないけど...)
    rand_train_params = torch.cat((rand_train_params, rand_candidates), dim=0)
    rand_train_metrics = torch.cat((rand_train_metrics, rand_candidates_metrics), dim=0)
    rand_t1 = time.time()

    # tensorboardへ記録
    for metric_name in gp_metrics.metrics.keys():
        gp_writer.add_scalar(f"{metric_name}", gp_metrics.metrics[metric_name].optimal_metric, iter+1)
        rand_writer.add_scalar(f"{metric_name}", rand_metrics.metrics[metric_name].optimal_metric, iter+1)

        

    print(f"gp_time: {gp_t1 - gp_t0}")
    print(f"rand_time: {rand_t1 - rand_t0}")

initial params: 
tensor([[ 2.8715,  0.1486, 15.9543,  0.0000],
        [ 3.0004,  4.7735, 31.0274,  0.0000],
        [ 2.8715,  0.1486, 15.9543,  1.0000],
        [ 3.0004,  4.7735, 31.0274,  1.0000]], device='cuda:0',
       dtype=torch.float64)




gp_time: 6.022202253341675
rand_time: 0.0008833408355712891




gp_time: 3.470698118209839
rand_time: 0.0010285377502441406




gp_time: 3.86625599861145
rand_time: 0.0009217262268066406




gp_time: 3.873175859451294
rand_time: 0.0009446144104003906




gp_time: 1.685056209564209
rand_time: 0.0009958744049072266




gp_time: 4.04680061340332
rand_time: 0.0009083747863769531




gp_time: 5.571903467178345
rand_time: 0.0010149478912353516




gp_time: 8.189114332199097
rand_time: 0.0009126663208007812




gp_time: 5.978463172912598
rand_time: 0.0008692741394042969




gp_time: 3.666095495223999
rand_time: 0.0008902549743652344




gp_time: 7.484594345092773
rand_time: 0.0009024143218994141




gp_time: 7.105463743209839
rand_time: 0.0009524822235107422




gp_time: 6.184694528579712
rand_time: 0.0010266304016113281




gp_time: 6.660703897476196
rand_time: 0.0010607242584228516




gp_time: 6.063668727874756
rand_time: 0.001268625259399414




gp_time: 2.8437697887420654
rand_time: 0.0009076595306396484




gp_time: 2.765984058380127
rand_time: 0.0009024143218994141




gp_time: 2.3102657794952393
rand_time: 0.0010495185852050781




gp_time: 3.392512798309326
rand_time: 0.0012392997741699219




gp_time: 3.0049386024475098
rand_time: 0.0010111331939697266




gp_time: 3.578263998031616
rand_time: 0.0009446144104003906




gp_time: 3.8961870670318604
rand_time: 0.0008995532989501953




gp_time: 2.6765811443328857
rand_time: 0.0009911060333251953




gp_time: 4.414525747299194
rand_time: 0.0010094642639160156




gp_time: 4.280689239501953
rand_time: 0.0008509159088134766




gp_time: 3.747098922729492
rand_time: 0.000990152359008789




gp_time: 4.398438453674316
rand_time: 0.0008761882781982422


Trying again with a new set of initial conditions.


gp_time: 5.6330835819244385
rand_time: 0.0010120868682861328




gp_time: 2.113577127456665
rand_time: 0.0009527206420898438




gp_time: 2.2476706504821777
rand_time: 0.0009047985076904297




gp_time: 2.554039716720581
rand_time: 0.0009026527404785156




gp_time: 4.7014546394348145
rand_time: 0.000934600830078125




gp_time: 1.2027416229248047
rand_time: 0.0009567737579345703




gp_time: 2.7998030185699463
rand_time: 0.0012333393096923828




gp_time: 2.4035983085632324
rand_time: 0.0012650489807128906




gp_time: 2.235342264175415
rand_time: 0.0009334087371826172




gp_time: 1.4365370273590088
rand_time: 0.0009047985076904297




gp_time: 2.526395320892334
rand_time: 0.0009222030639648438




gp_time: 8.955311059951782
rand_time: 0.0009121894836425781




gp_time: 1.5550758838653564
rand_time: 0.0009000301361083984




gp_time: 3.7583048343658447
rand_time: 0.0009925365447998047




gp_time: 6.129531383514404
rand_time: 0.0009341239929199219




gp_time: 2.338500738143921
rand_time: 0.0011990070343017578




gp_time: 5.954014539718628
rand_time: 0.0008924007415771484




gp_time: 5.624032258987427
rand_time: 0.0009107589721679688




gp_time: 1.5184617042541504
rand_time: 0.0010025501251220703




gp_time: 2.375192880630493
rand_time: 0.0008921623229980469




gp_time: 1.379258155822754
rand_time: 0.0008988380432128906




gp_time: 2.697449207305908
rand_time: 0.0008976459503173828




gp_time: 1.955385446548462
rand_time: 0.0008904933929443359




gp_time: 3.0568912029266357
rand_time: 0.0008821487426757812




gp_time: 2.6737778186798096
rand_time: 0.0008909702301025391




gp_time: 2.4109859466552734
rand_time: 0.0008656978607177734




gp_time: 2.8125975131988525
rand_time: 0.0009033679962158203




gp_time: 4.199998617172241
rand_time: 0.0009140968322753906




gp_time: 2.5726401805877686
rand_time: 0.0010647773742675781




gp_time: 5.09293532371521
rand_time: 0.0008990764617919922




gp_time: 3.5034351348876953
rand_time: 0.0008819103240966797




gp_time: 6.899758815765381
rand_time: 0.0008847713470458984




gp_time: 2.249345064163208
rand_time: 0.0008840560913085938




gp_time: 1.7020618915557861
rand_time: 0.0008764266967773438




gp_time: 2.1658084392547607
rand_time: 0.0008540153503417969




gp_time: 2.4633047580718994
rand_time: 0.0008733272552490234




gp_time: 3.3499603271484375
rand_time: 0.0009002685546875




gp_time: 3.8607258796691895
rand_time: 0.0009644031524658203


Trying again with a new set of initial conditions.


gp_time: 4.308732271194458
rand_time: 0.0008935928344726562




NotPSDError: Matrix not positive definite after repeatedly adding jitter up to 1.0e-06.

## 気づき
- 一応GPのほうがよりメトリクスを最小化できている
- マルチタスクのときは行列サイズが大きくなってしまうため、行列の正定値性に関して問題が起こりやすそう
- このことから、SingleTaskGPリストのほうが良さそう