In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import ray
from ray.tune.search.optuna import OptunaSearch
from ray import tune
from ray.tune.schedulers import ASHAScheduler
import optuna

import numpy as np
import math

import os
import datetime

from matplotlib import pyplot as plt

from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import pickle

# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Echo the selected backend name as the cell output.
device.type

'cuda'

In [2]:
class CustomDataset(Dataset):
  """Map-style dataset that pairs each spectrum with its label.

  Args:
    x: Indexable collection of spectra; ``x[idx]`` must be accepted by
      ``torch.tensor``.
    y: Indexable collection of labels; its length defines the dataset length.
    target_transform: Optional callable applied to a label before it is turned
      into a tensor. ``False`` (the default) or ``None`` disables it.
  """

  def __init__(self, x, y, target_transform = False):
    self.labels = y
    self.spectrum = x
    self.target_transform = target_transform

  def __len__(self):
    """Number of samples, taken from the label collection."""
    return len(self.labels)

  def __getitem__(self, idx):
    """Return ``(spectrum_tensor, label_tensor)`` for sample ``idx``."""
    spectra = self.spectrum[idx]
    label = self.labels[idx]

    transform = self.target_transform
    # `None` is accepted as "no transform" next to the legacy `False` sentinel;
    # previously passing `None` crashed with "'NoneType' object is not callable".
    if transform is not None and transform != False:
      label = transform(label)

    return torch.tensor(spectra), torch.tensor(label)

In [3]:
class customLoss(nn.Module):
  """Weighted mix of two binary-cross-entropy terms and one MSE term.

  ``output`` is indexed as ``output[head]`` and ``target`` as
  ``target[:, head]`` (it is permuted to head-first inside ``forward``).
  Heads 0 and 1 are scored with BCE, head 2 with MSE.

  Args:
    weight: Three coefficients, one per head loss. The combined loss is divided
      by ``sum(weight)``.
    punish_distance: When ``== True``, both BCE terms get a per-sample weight
      built from the arg-max distance between head-0 prediction and target (for
      samples whose head-0 target is non-empty) or from the number of head-0
      outputs above 0.1 (for samples whose head-0 target sums to zero).
    punish_Lya: When ``!= False`` (and ``punish_distance`` is off), samples whose
      head-0 target is non-empty get this value as relative weight, all others 1.
  """

  def __init__(self, weight = [1, 1, 1], punish_distance = False, punish_Lya = False):
    super().__init__()
    self.weight = weight
    self.punish_distance = punish_distance
    self.punish_Lya = punish_Lya

  def forward(self, output, target):
    """Return ``(combined_loss, bce_head0, bce_head1, mse_head2)``."""
    coeff = self.weight
    head_target = target.permute(1, 0, 2)
    pred0, true0 = output[0], head_target[0]

    # Optional per-sample BCE weight. It is rescaled so the weights sum to 64 and
    # then repeated across 78 columns (both constants are hard-coded here).
    sample_weight = None
    if self.punish_distance == True:
      occupied = torch.sum(true0, dim = 1)
      bin_error = torch.abs(torch.argmax(pred0, dim = 1) - torch.argmax(true0, dim = 1))
      spurious = (pred0 > 0.1).sum(dim = 1)
      raw = bin_error * occupied + spurious * (1 - occupied)
      denom = max(torch.sum(raw).item(), 1)
      sample_weight = (raw / denom * 64).unsqueeze(dim = 1).tile(1, 78)
    elif self.punish_Lya != False:
      raw = torch.where(torch.sum(true0, dim = -1) > 0, self.punish_Lya, 1)
      sample_weight = ((raw / torch.sum(raw).item()) * 64).unsqueeze(dim = -1).tile(1, 78)

    # The same (possibly weighted) BCE is applied to head 0 and head 1.
    loss_1 = nn.functional.binary_cross_entropy(output[0], head_target[0], weight = sample_weight, reduction = "mean")
    loss_2 = nn.functional.binary_cross_entropy(output[1], head_target[1], weight = sample_weight, reduction = "mean")
    loss_3 = nn.functional.mse_loss(output[2], head_target[2], reduction = "mean")

    combined = coeff[0] * loss_1 + coeff[1] * loss_2 + coeff[2] * loss_3

    return combined / sum(coeff), loss_1, loss_2, loss_3

In [4]:
class ConvNet(nn.Module):
    """Three-headed 1-D CNN operating on a single-channel sequence.

    Stage 1 (``layer1_*`` + ``layer2``) runs on the whole input and produces 78
    sigmoid scores per sample. The arg-max of those scores selects a 95-sample
    window of the input, which feeds two small conv+MLP branches:
    ``layer3_*``/``layer5`` (one sigmoid value) and ``layer4_*``/``layer6`` (one
    unbounded value). Both branch outputs are broadcast over the 78 positions
    and zeroed wherever the stage-1 score is not above ``thres``.

    Args:
        kernel_size1_1: Kernel size of the first stage-1 convolution.
        kernel_size3_1, kernel_size3_2: Kernel sizes of the ``layer3_*`` branch.
        kernel_size4_1, kernel_size4_2: Kernel sizes of the ``layer4_*`` branch.
        ratio: Growth factor; stage-1 kernels 2-4 are
            ``kernel_size1_1 * ratio ** k`` rounded to an odd integer.
        out_channel7: Channel count of the last stage-1 convolution.
        pooling: Max-pool kernel size (stride ``pooling - 1``, padding
            ``pooling // 2``).
        fc_width1, fc_width2, fc_width3: Hidden widths of ``layer2``,
            ``layer5`` and ``layer6``.
        rescale_1, rescale_2, rescale_3: ``"mean"``, ``"zscore"`` or anything
            else (no rescaling) for the stage-1 input and the two crop branches.
        thres: Stage-1 score threshold used to build the output mask.
        factor: Multiplier applied to the base stage-1 channel counts.

    Notes:
        ``layer2`` expects ``610 * out_channel7`` flattened features and the
        branch MLPs expect 1900 (= 20 channels x 95 samples). These sizes are
        hard-coded, so the input length / ``pooling`` combination must produce
        them (e.g. length 2438 with ``pooling = 3``), and all kernel sizes
        should be odd so that ``padding = k // 2`` preserves length.
    """

    def __init__(self, kernel_size1_1, kernel_size3_1, kernel_size3_2, kernel_size4_1, kernel_size4_2, ratio, out_channel7, pooling, fc_width1, fc_width2, fc_width3, rescale_1, rescale_2, rescale_3, thres, factor):
        # Initialise nn.Module before assigning any attribute.
        super(ConvNet, self).__init__()

        self.rescale_1 = rescale_1
        self.rescale_2 = rescale_2
        self.rescale_3 = rescale_3
        self.thres = thres

        kernel_size1_2 = self._odd_kernel(kernel_size1_1 * ratio)
        kernel_size1_3 = self._odd_kernel(kernel_size1_1 * ratio ** 2)
        kernel_size1_4 = self._odd_kernel(kernel_size1_1 * ratio ** 3)

        out_channel = int(10 * factor)
        out_channel2 = int(10 * factor)
        out_channel3 = int(20 * factor)
        out_channel4 = int(20 * factor)
        out_channel5 = int(40 * factor)
        out_channel6 = int(80 * factor)

        # Attribute names and in-Sequential ordering are unchanged, so existing
        # state_dict checkpoints keep loading.
        self.layer1_1 = self._conv_block(1, out_channel, kernel_size1_1)
        self.layer1_2 = self._conv_block(out_channel, out_channel2, kernel_size1_2)
        self.layer1_3 = self._conv_block(out_channel2, out_channel3, kernel_size1_3, pooling = pooling)
        self.layer1_4 = self._conv_block(out_channel3, out_channel4, kernel_size1_4, pooling = pooling)
        self.layer1_5 = self._conv_block(out_channel4, out_channel5, 1)
        self.layer1_6 = self._conv_block(out_channel5, out_channel6, 1)
        self.layer1_7 = nn.Sequential(self._conv(out_channel6, out_channel7, 1))

        self.layer2 = self._mlp(610 * out_channel7, fc_width1, hidden_layers = 4, out_features = 78, final_sigmoid = True)

        self.layer3_1 = self._conv_block(1, 20, kernel_size3_1)
        self.layer3_2 = nn.Sequential(self._conv(20, 20, kernel_size3_2))

        self.layer4_1 = self._conv_block(1, 20, kernel_size4_1)
        self.layer4_2 = nn.Sequential(self._conv(20, 20, kernel_size4_2))

        self.layer5 = self._mlp(1900, fc_width2, hidden_layers = 4, out_features = 1, final_sigmoid = True)
        self.layer6 = self._mlp(1900, fc_width3, hidden_layers = 5, out_features = 1, final_sigmoid = False)

    @staticmethod
    def _odd_kernel(value):
        """Round ``value`` to an integer and bump even results up by one."""
        size = round(value)
        return size + 1 if size % 2 == 0 else size

    @staticmethod
    def _conv(in_channels, out_channels, kernel_size):
        """Stride-1, dilation-1 Conv1d padded with ``kernel_size // 2``."""
        return nn.Conv1d(in_channels, out_channels, kernel_size = kernel_size, stride = 1, padding = kernel_size // 2, dilation = 1)

    @classmethod
    def _conv_block(cls, in_channels, out_channels, kernel_size, pooling = None):
        """Conv1d -> BatchNorm1d -> ReLU, optionally followed by MaxPool1d."""
        layers = [
            cls._conv(in_channels, out_channels, kernel_size),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        ]
        if pooling is not None:
            layers.append(nn.MaxPool1d(kernel_size = pooling, stride = pooling - 1, padding = pooling // 2))
        return nn.Sequential(*layers)

    @staticmethod
    def _mlp(in_features, width, hidden_layers, out_features, final_sigmoid):
        """``hidden_layers`` x (Linear -> BatchNorm1d -> ReLU), then Linear [-> Sigmoid]."""
        layers = []
        fan_in = in_features
        for _ in range(hidden_layers):
            layers += [nn.Linear(fan_in, width), nn.BatchNorm1d(width), nn.ReLU()]
            fan_in = width
        layers.append(nn.Linear(fan_in, out_features))
        if final_sigmoid:
            layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    @staticmethod
    def _rescale(values, mode):
        """Centre (``"mean"``) or standardise (``"zscore"``) along the last axis; otherwise copy."""
        if mode == "mean":
            return values - torch.mean(values, dim = -1, keepdim = True)
        if mode == "zscore":
            # NOTE(review): a constant row has zero std and yields inf/NaN here.
            return (values - torch.mean(values, dim = -1, keepdim = True)) / torch.std(values, dim = -1, keepdim = True)
        return torch.clone(values)

    def forward(self, x):
        """Run all three heads.

        Args:
            x: Tensor indexed as (batch, 1, length).

        Returns:
            Tuple ``(heads, x_scale_1, x_cropped, x_scale_2, x_scale_3)`` where
            ``heads`` stacks the three head outputs along a new leading axis and
            the remaining entries are the rescaled full input, the raw cropped
            window, and the two rescaled crops.
        """
        x_scale_1 = self._rescale(x, self.rescale_1)

        features = x_scale_1
        for stage in (self.layer1_1, self.layer1_2, self.layer1_3, self.layer1_4, self.layer1_5, self.layer1_6, self.layer1_7):
            features = stage(features)
        x2 = self.layer2(features.view(features.size(0), -1))

        # 1 where the stage-1 score exceeds the threshold, 0 elsewhere; built on
        # x2's own device instead of the notebook-global `device`.
        x2_mask = (x2 > self.thres).to(x2.dtype)

        # Window centre = arg-max position * 31.25 + 15.625; the extra +48
        # compensates for the 48-sample padding prepended below.
        window_offsets = torch.arange(-47, 48, device = x2.device)
        centres = torch.round(torch.argmax(x2, dim = -1) * 31.25 + 15.625).unsqueeze(1)
        crop_index1 = (centres + window_offsets + 48).type(torch.long)

        # Pad both ends with the global median so windows near the edges stay in range.
        spectrum = x.squeeze(dim = 1)
        edge_pad = torch.full_like(spectrum[:, :48], torch.median(x).item())
        x_extended = torch.cat((edge_pad, spectrum, edge_pad), dim = -1)
        x_cropped = torch.gather(x_extended, 1, crop_index1).unsqueeze(dim = 1)

        x_scale_2 = self._rescale(x_cropped, self.rescale_2)
        x3 = self.layer3_2(self.layer3_1(x_scale_2))
        x5 = x2_mask * self.layer5(x3.view(x3.size(0), -1))

        x_scale_3 = self._rescale(x_cropped, self.rescale_3)
        x4 = self.layer4_2(self.layer4_1(x_scale_3))
        x6 = x2_mask * self.layer6(x4.view(x4.size(0), -1))

        heads = torch.stack((x2, x5, x6), dim = 0)

        return (heads, x_scale_1, x_cropped, x_scale_2, x_scale_3)

In [14]:
# Load the train (iron) / test (fuji) spectra and labels, shuffle them
# consistently, and publish the arrays to the Ray object store.

def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this at trusted files.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


def _load_ivar_weighted_flux(path):
    """Load pickled spectrum records and compute ``FLUX * IVAR ** 0.5`` for each.

    Returns:
        ``(records, weighted)`` where ``records`` is a list copy of the pickled
        records and ``weighted`` stacks one ``(1, -1)``-reshaped row per record.
    """
    records = _load_pickle(path)
    flux = np.array([np.array(rec["FLUX"]).reshape(1, -1) for rec in records])
    ivar = np.array([np.array(rec["IVAR"]).reshape(1, -1) for rec in records])
    return list(records), flux * ivar ** (1 / 2)


spectra_info_test, fuji = _load_ivar_weighted_flux("/pscratch/sd/j/juikuan/DESI_LAE_dataset/train_spectra/fuji_pre_b.pkl")
fuji_label = _load_pickle("/pscratch/sd/j/juikuan/DESI_LAE_dataset/train_label/fuji_pre_25.pkl")

spectra_info_train, iron = _load_ivar_weighted_flux("/pscratch/sd/j/juikuan/DESI_LAE_dataset/train_spectra/iron_pre_b.pkl")
iron_label = _load_pickle("/pscratch/sd/j/juikuan/DESI_LAE_dataset/train_label/iron_pre_25.pkl")

# Re-seeding before every shuffle only yields the same permutation when the
# sequences have the same length, so check that first.
assert len(spectra_info_test) == len(fuji) == len(fuji_label), "fuji spectra/labels are misaligned"
assert len(spectra_info_train) == len(iron) == len(iron_label), "iron spectra/labels are misaligned"

SHUFFLE_SEED = 3
# The first entry used to be the undefined name `spectra_info` (NameError on a
# fresh kernel); the training metadata list is what belongs here.
for sequence in (spectra_info_train, spectra_info_test, fuji, fuji_label, iron, iron_label):
    np.random.seed(SHUFFLE_SEED)
    np.random.shuffle(sequence)

x_train = iron
y_train = iron_label

x_test = fuji
y_test = fuji_label

# Put the arrays in the Ray object store once so every trial can fetch them by reference.
x_train_id = ray.put(x_train)
y_train_id = ray.put(y_train)
x_test_id = ray.put(x_test)
y_test_id = ray.put(y_test)

In [6]:
def train(epoch, model, optimizer, criterion, train_loader, thres):
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoch: Current epoch number (not used inside the function).
        model: Network whose first return value is the stacked head output.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Callable returning ``(loss, loss_1, loss_2, loss_3)``.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        thres: Score threshold; head-0 predictions below it are reported as
            position 0, and it masks the head-1/head-2 outputs (same rule as
            ``test``).

    Returns:
        ``(model, train_loss_iter, output1_record, output2_record, output3_record)``.
        ``train_loss_iter`` is the summed batch loss divided by the number of
        samples in the dataset; each record is a ``(2, n_samples)`` tensor of
        predictions (row 0) and targets (row 1).
    """
    # Device/dtype placement does not change between batches, so do it once.
    model = model.to(device).float()

    train_loss_iter = 0
    records_1, records_2, records_3 = [], [], []
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        results = model(data.float())
        output = results[0]

        loss, loss_1, loss_2, loss_3 = criterion(output.float(), target.float())
        loss.backward()
        optimizer.step()
        # Detach before accumulating; otherwise the running total keeps every
        # batch's autograd graph alive until the epoch ends.
        train_loss_iter += loss.detach()

        output = output.detach()
        target_heads = target.permute(1, 0, 2)

        output1 = output[0].argmax(dim = -1)
        output1[output[0].max(dim = -1)[0] < thres] = 0
        target1 = target_heads[0].argmax(dim = -1)
        records_1.append(torch.cat((output1.unsqueeze(0), target1.unsqueeze(0)), dim = 0))

        # Uses `thres` (was a hard-coded 0.7) so training records follow the
        # same masking rule as `test`.
        mask1 = torch.where(output[0] > thres, 1, 0)

        output2 = torch.sum(output[1] * mask1, dim = -1)
        target2 = torch.sum(target_heads[1], dim = -1)
        records_2.append(torch.cat((output2.unsqueeze(0), target2.unsqueeze(0)), dim = 0))

        output3 = torch.sum(output[2] * mask1, dim = -1)
        target3 = torch.sum(target_heads[2], dim = -1)
        records_3.append(torch.cat((output3.unsqueeze(0), target3.unsqueeze(0)), dim = 0))

    train_loss_iter /= len(train_loader.dataset)

    # One concatenation per record instead of growing a tensor every batch. The
    # empty float seed keeps the (2, 0) result for an empty loader and the same
    # dtype promotion as before.
    seed = torch.tensor([[], []], device = device)
    output1_record = torch.cat([seed] + records_1, dim = -1)
    output2_record = torch.cat([seed] + records_2, dim = -1)
    output3_record = torch.cat([seed] + records_3, dim = -1)

    return model, train_loss_iter, output1_record, output2_record, output3_record

In [7]:
def test(epoch, model, criterion, test_loader, ray_tune, thres):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Args:
        epoch: Current epoch number (not used inside the function).
        model: Network whose first return value is the stacked head output.
        criterion: Callable returning ``(loss, loss_1, loss_2, loss_3)``.
        test_loader: DataLoader yielding ``(data, target)`` batches.
        ray_tune: When falsy, a one-line summary is printed.
        thres: Score threshold; head-0 predictions below it are reported as
            position 0, and it masks the head-1/head-2 outputs.

    Returns:
        ``(model, losses, accuracies, output1_record, output2_record, output3_record)``
        where ``losses`` holds the total and per-head summed batch losses divided
        by the number of samples, ``accuracies`` holds the fraction of samples
        whose squared error is below 1 / 0.0025 / 9 for heads 0 / 1 / 2, and each
        record is a ``(2, n_samples)`` tensor of predictions and targets.
    """
    # Device/dtype placement does not change between batches, so do it once.
    model = model.to(device).float()

    test_loss_iter = 0
    test_loss1_iter = 0
    test_loss2_iter = 0
    test_loss3_iter = 0
    correct_1 = 0
    correct_2 = 0
    correct_3 = 0
    records_1, records_2, records_3 = [], [], []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            results = model(data.float())
            output = results[0]

            loss, loss_1, loss_2, loss_3 = criterion(output.float(), target.float())
            test_loss_iter += loss
            test_loss1_iter += loss_1
            test_loss2_iter += loss_2
            test_loss3_iter += loss_3

            target_heads = target.permute(1, 0, 2)

            output1 = output[0].argmax(dim = -1)
            output1[output[0].max(dim = -1)[0] < thres] = 0
            target1 = target_heads[0].argmax(dim = -1)
            difference1 = (output1 - target1) ** 2
            correct_1 += len(difference1[difference1 < 1])
            records_1.append(torch.cat((output1.unsqueeze(0), target1.unsqueeze(0)), dim = 0))

            mask1 = torch.where(output[0] > thres, 1, 0)

            output2 = torch.sum(output[1] * mask1, dim = -1)
            target2 = torch.sum(target_heads[1], dim = -1)
            difference2 = (output2 - target2) ** 2
            correct_2 += len(difference2[difference2 < 0.0025])
            records_2.append(torch.cat((output2.unsqueeze(0), target2.unsqueeze(0)), dim = 0))

            output3 = torch.sum(output[2] * mask1, dim = -1)
            target3 = torch.sum(target_heads[2], dim = -1)
            difference3 = (output3 - target3) ** 2
            correct_3 += len(difference3[difference3 < 9])
            records_3.append(torch.cat((output3.unsqueeze(0), target3.unsqueeze(0)), dim = 0))

    # One concatenation per record instead of growing a tensor every batch. The
    # empty float seed keeps the (2, 0) result for an empty loader and the same
    # dtype promotion as before.
    seed = torch.tensor([[], []], device = device)
    output1_record = torch.cat([seed] + records_1, dim = -1)
    output2_record = torch.cat([seed] + records_2, dim = -1)
    output3_record = torch.cat([seed] + records_3, dim = -1)

    data_count = len(test_loader.dataset)

    # Batch-mean losses are summed and then divided by the sample count.
    test_loss_iter /= data_count
    test_loss1_iter /= data_count
    test_loss2_iter /= data_count
    test_loss3_iter /= data_count

    if not ray_tune:
        print(f'平均損失: {test_loss_iter:.6f}, 1st loss: {test_loss1_iter:.6f}({correct_1 / data_count * 100:.1f}), 2nd loss: {test_loss2_iter:.6f}({correct_2 / data_count * 100:.1f}), 3rd loss: {test_loss3_iter:.6f}({correct_3 / data_count * 100:.1f})')

    return model, (test_loss_iter, test_loss1_iter, test_loss2_iter, test_loss3_iter), (correct_1 / data_count, correct_2 / data_count, correct_3 / data_count), output1_record, output2_record, output3_record

In [8]:
def init_weights(m):
    """Re-initialise the weight of ``nn.Linear`` modules with Kaiming-uniform (fan-in, ReLU gain).

    Intended for ``model.apply(init_weights)``; every other module type, and
    the bias of linear layers, is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_uniform_(m.weight, mode = 'fan_in', nonlinearity = 'relu')

In [9]:
def build_model(config, x_train_id = x_train_id, y_train_id = y_train_id, x_test_id = x_test_id, y_test_id = y_test_id, ray_tune = True, model_size = "M"):
    """Ray Tune trainable: build a ``ConvNet`` from ``config``, train it and report test accuracy.

    Args:
        config: Hyper-parameter dict. Keys read here: ``ratio``, ``batch``,
            ``ks1_1``, ``ks3_1``, ``ks3_2``, ``ks4_1``, ``ks4_2``, ``ks_ratio``,
            ``out_ch7``, ``ps``, ``w1``, ``w2``, ``w3``, ``rescale1``-``rescale3``,
            ``thres``, ``factor``, ``lw1``-``lw3`` and ``lr``.
        x_train_id, y_train_id, x_test_id, y_test_id: Ray object references to
            the train/test arrays.
        ray_tune: Forwarded to ``test``; when falsy a per-epoch summary is printed.
        model_size: Not used inside the function.

    Each of the 60 epochs ends with ``ray.train.report`` carrying the three
    test accuracies under ``acc1``, ``acc2`` and ``acc3``.
    """
    x_all = ray.get(x_train_id)
    y_all = ray.get(y_train_id)

    # Draw `ratio * n` training indices uniformly *with replacement*.
    length = round(config["ratio"] * x_all.shape[0])
    if length % int(config["batch"]) == 1:
        # One extra sample so the last batch does not hold a single element
        # (presumably because BatchNorm1d rejects size-1 batches while training).
        length += 1
    random_indices = np.random.randint(x_all.shape[0], size = length)

    x = x_all[random_indices]
    y = y_all[random_indices]

    train_dataset = CustomDataset(x = x, y = y)
    test_dataset = CustomDataset(x = ray.get(x_test_id), y = ray.get(y_test_id))

    train_loader = DataLoader(train_dataset, batch_size = config["batch"], num_workers = 128, pin_memory = True)
    test_loader = DataLoader(test_dataset, batch_size = config["batch"], num_workers = 128, pin_memory = True)

    model = ConvNet(kernel_size1_1 = config["ks1_1"], kernel_size4_2 = config["ks4_2"],
                  kernel_size3_1 = config["ks3_1"], thres = config["thres"],
                  kernel_size3_2 = config["ks3_2"], kernel_size4_1 = config["ks4_1"],
                  ratio = config["ks_ratio"], out_channel7 = config["out_ch7"],
                  pooling = config["ps"], fc_width1 = config["w1"],
                  fc_width2 = config["w2"], fc_width3 = config["w3"],
                  rescale_1 = config["rescale1"], rescale_2 = config["rescale2"],
                  rescale_3 = config["rescale3"], factor = config["factor"])

    criterion = customLoss(weight = [config["lw1"], config["lw2"], config["lw3"]])
    optimizer = optim.Adam(model.parameters(), lr = config["lr"])

    epoch = 60

    # Re-initialises the Linear weights in place, so the optimizer still sees the same parameters.
    model.apply(init_weights)

    for i in range(1, epoch + 1):
        model.train()
        model, train_loss_iter, output1_record, output2_record, output3_record = train(epoch = i, model = model, criterion = criterion, optimizer = optimizer, train_loader = train_loader, thres = config["thres"])

        model.eval()
        model, test_loss_iter, test_accuracy_iter, output1_record, output2_record, output3_record = test(epoch = i, model = model, criterion = criterion, test_loader = test_loader, ray_tune = ray_tune, thres = config["thres"])

        # Report this epoch's accuracies; the previous code read the undefined
        # lists test_acc_1/2/3 and raised NameError after the first epoch.
        acc1, acc2, acc3 = test_accuracy_iter
        ray.train.report({"acc1": acc1, "acc2": acc2, "acc3": acc3})

In [10]:
def plot_loss_iter(test_loss, train_loss):
    """Plot test and train loss against the 1-based epoch number.

    The x axis is derived from ``len(test_loss)``, so both sequences are
    expected to hold one value per epoch.
    """
    plt.close()
    epochs = list(range(1, len(test_loss) + 1))
    curves = (
        (test_loss, "Test Loss", "m"),
        (train_loss, "Train Loss", "y"),
    )
    for values, caption, colour in curves:
        plt.plot(epochs, values, label = caption, c = colour)
    plt.legend()
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

def plot_test_accuracy(test_acc1, test_acc2, test_acc3):
    """Plot the three per-epoch test accuracies on one set of axes.

    The x axis is the 1-based epoch number derived from ``len(test_acc1)``.
    """
    plt.close()
    epochs = list(range(1, len(test_acc1) + 1))
    curves = (
        (test_acc1, "Rough Position", "r"),
        (test_acc2, "Precise Position", "m"),
        (test_acc3, "FWHM", "y"),
    )
    for values, caption, colour in curves:
        plt.plot(epochs, values, label = caption, c = colour)
    plt.legend()
    plt.xlabel("Epoch")
    plt.ylabel("Test Accuracy")
    plt.show()

In [None]:
# Hyper-parameter search: Optuna proposes configurations and ASHA stops weak
# trials early; both optimise the "acc1" metric reported by build_model.
TUNE_METRIC = "acc1"
LAYER_WIDTHS = [8, 16, 32, 64, 128, 256, 512, 1024]


def _fixed(value):
    """Single-option categorical: keeps the key in the config without searching over it."""
    return tune.choice([value])


search_space = {
    "lr": _fixed(0.005368),
    "ks1_1": tune.qrandint(9, 45, 2),
    "ks3_1": _fixed(3),
    "ks3_2": _fixed(3),
    "ks4_1": _fixed(3),
    "ks4_2": _fixed(3),
    "batch": _fixed(8),
    "ks_ratio": tune.loguniform(0.1, 10),
    "out_ch7": tune.choice(LAYER_WIDTHS),
    "ps": _fixed(3),
    "lw1": _fixed(1),
    "lw2": _fixed(1 / 5),
    "lw3": _fixed(1 / 2 * 10 ** -2),
    "w1": tune.choice(LAYER_WIDTHS),
    "w2": tune.choice(LAYER_WIDTHS),
    "w3": tune.choice(LAYER_WIDTHS),
    "rescale1": _fixed("mean"),
    "rescale2": _fixed("mean"),
    "rescale3": _fixed("mean"),
    "LAE_w": _fixed(False),
    "ratio": _fixed(1),
    "thres": tune.uniform(0.5, 1),
    "factor": tune.loguniform(0.1, 10)
}

optuna_search = OptunaSearch(metric = TUNE_METRIC)

# The optimisation direction ("max") is supplied through tune.run below.
scheduler = ASHAScheduler(
    metric = TUNE_METRIC,
    max_t = 100,
    grace_period = 10,
    reduction_factor = 2
)

# One GPU per trial.
trainable_with_gpu = tune.with_resources(build_model, {"gpu": 1})
analysis = tune.run(
    trainable_with_gpu,
    search_alg = optuna_search,
    scheduler = scheduler,
    num_samples = 400,
    mode = "max",
    config = search_space,
    reuse_actors = False
)

In [None]:
# Grid search over the training-set fraction ("ratio") with every other
# hyper-parameter pinned to a single value; the whole grid is run 8 times.
grid_values = {
    "lr": [0.005368],
    "ks1_1": [13],
    "ks3_1": [3],
    "ks3_2": [3],
    "ks4_1": [3],
    "ks4_2": [3],
    "batch": [8],
    "ks_ratio": [1.324924256],
    "out_ch7": [512],
    "ps": [3],
    "lw1": [1],
    "lw2": [1 / 5],
    "lw3": [1 / 2 * 10 ** -2],
    "w1": [32],
    "w2": [128],
    "w3": [512],
    "rescale1": ["mean"],
    "rescale2": ["mean"],
    "rescale3": ["mean"],
    "LAE_w": [False],
    "ratio": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1],
    "thres": [0.554768217],
    "factor": [6.755653316]
}
search_space = {name: tune.grid_search(options) for name, options in grid_values.items()}

# One GPU per trial.
trainable_with_gpu = tune.with_resources(build_model, {"gpu": 1})
tuner = tune.Tuner(
    trainable_with_gpu,
    param_space = search_space,
    tune_config = tune.TuneConfig(num_samples = 8)
)
results = tuner.fit()