In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from adan_pytorch import Adan
import copy
from utils import preparte_data_loader
from utils import set_parameter_requires_grad
from utils import DEVICE, LOSS_CRITERIA
import time
from transformers import HubertConfig, HubertForSequenceClassification
from utils import ONLY_10_LABELS
import matplotlib.pyplot as plt
from utils import DEVICE, AudioDataset
from torch.utils.data import DataLoader
import optuna
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
batch_size = 64
num_labels = 50#10 if ONLY_10_LABELS else 50
epochs = 10
device=DEVICE
loss_criteria = LOSS_CRITERIA
# train_dataloader = preparte_data_loader(mode='train', batch_size=20)
# val_dataloader = preparte_data_loader(mode='val', batch_size=20)
train_dataset = AudioDataset(kind='train')
val_dataset = AudioDataset(kind='val')
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

dataloaders = {
    'train':train_loader,
    'val':val_loader
}



def optune_optimizer_for_model(trial, optimizer_name):
    
    config = HubertConfig(
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    num_labels=num_labels,
    conv_dim=(512, 512, 512),
    conv_stride=(5, 2, 2),
    conv_kernel=(10, 3, 3)
    )

    hubert_model = HubertForSequenceClassification(config)
    
    model = hubert_model.to(DEVICE)
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)  # log=True, will use log scale to interplolate between lr
    #trial.suggest_categorical("optimizer", ["Adam", "SGD","Adan"])
    if optimizer_name == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif optimizer_name == 'SGD':
        sgd_momentum = trial.suggest_float("sgd_momentum", 1e-1, 1)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=sgd_momentum, nesterov=True)
    elif optimizer_name == 'Adan':
        beta1 = trial.suggest_float("beta1", 1e-3, 1e-1)
        beta2 = trial.suggest_float("beta2", 1e-3, 1e-1)
        beta3 = trial.suggest_float("beta3", 1e-3, 1e-1)
        optimizer = Adan(model.parameters(),lr = lr,
            betas = (beta1, beta2, beta3), 
            weight_decay = 0.02         # weight decay 0.02 is optimal per author
        )
        
    
    # out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)  # number of units will be between 4 and 128
    

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)
        
        model.train()
        
        # Iterate over data.
        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.set_grad_enabled(True):
                # Get model outputs and calculate loss
                logits_h = model(inputs).logits.type(torch.cuda.FloatTensor)
                labels = labels.type(torch.cuda.FloatTensor)
                loss = loss_criteria(logits_h.view(-1), labels.view(-1))
                # loss = loss_criteria(outputs, labels)

                _, preds = torch.max(logits_h, 1)
                _, real_labels = torch.max(labels, 1)

                # backward + optimize only if in training phase
                # zero the parameter gradients
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
        model.eval()
        running_corrects = 0.0
        for inputs, labels in dataloaders['val']:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # forward
                # track history if only in train
                with torch.set_grad_enabled(False):
                    # Get model outputs and calculate loss
                    # outputs = model(inputs)
                    
                    logits_h = model(inputs).logits.type(torch.cuda.FloatTensor)
                    labels = labels.type(torch.cuda.FloatTensor)
                    loss = loss_criteria(logits_h.view(-1), labels.view(-1))
                    # loss = loss_criteria(outputs, labels)

                    _, preds = torch.max(logits_h, 1)
                    _, real_labels = torch.max(labels, 1)
                    
                    # loss = loss_criteria(outputs, labels)
                    # _, preds = torch.max(outputs, 1)

                # statistics
                running_corrects += torch.sum(preds == real_labels.data)

        epoch_acc = running_corrects.double() / len(dataloaders['val'].dataset)

        # report back to Optuna how far it is (epoch-wise) into the trial and how well it is doing (accuracy)
        trial.report(epoch_acc, epoch)  

        # then, Optuna can decide if the trial should be pruned
        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return epoch_acc

In [9]:
# now we can run the experiment
sampler = optuna.samplers.TPESampler()
configurations = [{'optimizer_name':'Adam',
                   'study_name':'adam_study'},
                  {'optimizer_name':'SGD',
                   'study_name':'sgd_study'},
                  {'optimizer_name':'Adan',
                   'study_name':'adan_study'}
                  ]
studies = {}
for config in configurations:
    
    study = optuna.create_study(study_name=config['study_name'], direction="maximize", sampler=sampler)
    study.optimize(lambda trial: optune_optimizer_for_model(trial, config['optimizer_name']), n_trials=500, timeout=600)
    
    study_csv = f'{config["study_name"]}.csv'
    study.trials_dataframe().to_csv(study_csv)
    studies[config["study_name"]] = study_csv
    
    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))
    print("Best trial:")

    trial = study.best_trial

    print("  Value: ", trial.value)
    print("  Params: ")

    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

[I 2024-08-07 22:01:05,634] A new study created in memory with name: adam_study


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:02:11,162] Trial 0 finished with value: 0.1 and parameters: {'lr': 0.023412811673960664}. Best is trial 0 with value: 0.1.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:03:07,826] Trial 1 finished with value: 0.1 and parameters: {'lr': 0.03933215121713092}. Best is trial 0 with value: 0.1.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:04:05,668] Trial 2 finished with value: 0.1 and parameters: {'lr': 0.017175042243039625}. Best is trial 0 with value: 0.1.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:05:05,497] Trial 3 finished with value: 0.1 and parameters: {'lr': 0.003114515485764312}. Best is trial 0 with value: 0.1.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:06:05,273] Trial 4 finished with value: 0.1 and parameters: {'lr': 0.041703663579044085}. Best is trial 0 with value: 0.1.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:07:03,343] Trial 5 finished with value: 0.42916666666666664 and parameters: {'lr': 0.002174936779900449}. Best is trial 5 with value: 0.42916666666666664.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:08:03,042] Trial 6 finished with value: 0.125 and parameters: {'lr': 2.1171025325901415e-05}. Best is trial 5 with value: 0.42916666666666664.


Epoch 0/9
----------


[I 2024-08-07 22:08:09,457] Trial 7 pruned. 


Epoch 0/9
----------


[I 2024-08-07 22:08:15,044] Trial 8 pruned. 


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:09:15,215] Trial 9 finished with value: 0.4125 and parameters: {'lr': 0.0010588662281037055}. Best is trial 5 with value: 0.42916666666666664.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:10:15,413] Trial 10 finished with value: 0.25416666666666665 and parameters: {'lr': 0.0001404159824138537}. Best is trial 5 with value: 0.42916666666666664.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:11:15,763] Trial 11 finished with value: 0.5041666666666667 and parameters: {'lr': 0.0012161288828451268}. Best is trial 11 with value: 0.5041666666666667.
[I 2024-08-07 22:11:15,778] A new study created in memory with name: sgd_study


Study statistics: 
  Number of finished trials:  12
  Number of pruned trials:  2
  Number of complete trials:  10
Best trial:
  Value:  0.5041666666666667
  Params: 
    lr: 0.0012161288828451268
Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:12:16,538] Trial 0 finished with value: 0.1 and parameters: {'lr': 0.03495458556963842, 'sgd_momentum': 0.6156705499539906}. Best is trial 0 with value: 0.1.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:13:16,881] Trial 1 finished with value: 0.2583333333333333 and parameters: {'lr': 0.0002641417381613565, 'sgd_momentum': 0.6614516736520298}. Best is trial 1 with value: 0.2583333333333333.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:14:15,706] Trial 2 finished with value: 0.1 and parameters: {'lr': 0.019511410730804427, 'sgd_momentum': 0.6323725390074426}. Best is trial 1 with value: 0.2583333333333333.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:15:15,093] Trial 3 finished with value: 0.15833333333333333 and parameters: {'lr': 1.517033539743099e-05, 'sgd_momentum': 0.8866906267547824}. Best is trial 1 with value: 0.2583333333333333.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:16:12,795] Trial 4 finished with value: 0.1 and parameters: {'lr': 0.02767707274795679, 'sgd_momentum': 0.9673570725132417}. Best is trial 1 with value: 0.2583333333333333.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:17:13,249] Trial 5 finished with value: 0.325 and parameters: {'lr': 0.0005561224860038818, 'sgd_momentum': 0.7504513858587106}. Best is trial 5 with value: 0.325.


Epoch 0/9
----------


[I 2024-08-07 22:17:19,521] Trial 6 pruned. 


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:18:18,260] Trial 7 finished with value: 0.1 and parameters: {'lr': 0.005880489968384445, 'sgd_momentum': 0.29103939871526663}. Best is trial 5 with value: 0.325.


Epoch 0/9
----------


[I 2024-08-07 22:18:24,133] Trial 8 pruned. 


Epoch 0/9
----------


[I 2024-08-07 22:18:30,405] Trial 9 pruned. 


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:19:30,051] Trial 10 finished with value: 0.36666666666666664 and parameters: {'lr': 0.0004399074848905889, 'sgd_momentum': 0.782315314217227}. Best is trial 10 with value: 0.36666666666666664.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:20:28,095] Trial 11 finished with value: 0.3625 and parameters: {'lr': 0.0003935572288930598, 'sgd_momentum': 0.7941764192455438}. Best is trial 10 with value: 0.36666666666666664.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:21:28,055] Trial 12 finished with value: 0.225 and parameters: {'lr': 0.00012479299479413865, 'sgd_momentum': 0.8097157166614828}. Best is trial 10 with value: 0.36666666666666664.
[I 2024-08-07 22:21:28,065] A new study created in memory with name: adan_study


Study statistics: 
  Number of finished trials:  13
  Number of pruned trials:  3
  Number of complete trials:  10
Best trial:
  Value:  0.36666666666666664
  Params: 
    lr: 0.0004399074848905889
    sgd_momentum: 0.782315314217227
Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:22:31,944] Trial 0 finished with value: 0.1 and parameters: {'lr': 0.014167554462473827, 'beta1': 0.004227344088888265, 'beta2': 0.05128712458850703, 'beta3': 0.07192808989228096}. Best is trial 0 with value: 0.1.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:23:36,766] Trial 1 finished with value: 0.25 and parameters: {'lr': 0.007202794853936493, 'beta1': 0.09905421431054498, 'beta2': 0.010663097037349106, 'beta3': 0.0739774345467093}. Best is trial 1 with value: 0.25.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:24:39,965] Trial 2 finished with value: 0.42083333333333334 and parameters: {'lr': 0.004981408060246446, 'beta1': 0.0398841519885873, 'beta2': 0.040172319449575195, 'beta3': 0.030669382770363105}. Best is trial 2 with value: 0.42083333333333334.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:25:43,378] Trial 3 finished with value: 0.1 and parameters: {'lr': 0.05240921416962004, 'beta1': 0.05476766882557856, 'beta2': 0.060359871836554974, 'beta3': 0.04916153513077381}. Best is trial 2 with value: 0.42083333333333334.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:26:47,633] Trial 4 finished with value: 0.5208333333333334 and parameters: {'lr': 0.0020022613730624438, 'beta1': 0.046914599109421654, 'beta2': 0.019381017240822685, 'beta3': 0.010655285691959324}. Best is trial 4 with value: 0.5208333333333334.


Epoch 0/9
----------
Epoch 1/9
----------


[I 2024-08-07 22:27:01,007] Trial 5 pruned. 


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:28:08,384] Trial 6 finished with value: 0.4041666666666667 and parameters: {'lr': 0.0006957036586323713, 'beta1': 0.08457960498766, 'beta2': 0.060780462660088, 'beta3': 0.029694573943585167}. Best is trial 4 with value: 0.5208333333333334.


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------
Epoch 5/9
----------
Epoch 6/9
----------
Epoch 7/9
----------
Epoch 8/9
----------
Epoch 9/9
----------


[I 2024-08-07 22:29:12,075] Trial 7 finished with value: 0.4875 and parameters: {'lr': 0.0012964346688131474, 'beta1': 0.030499766021383105, 'beta2': 0.007088684308307633, 'beta3': 0.0072920870617401035}. Best is trial 4 with value: 0.5208333333333334.


Epoch 0/9
----------


[I 2024-08-07 22:29:18,890] Trial 8 pruned. 


Epoch 0/9
----------


[I 2024-08-07 22:29:24,856] Trial 9 pruned. 


Epoch 0/9
----------


[I 2024-08-07 22:29:31,695] Trial 10 pruned. 


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------


[I 2024-08-07 22:30:04,716] Trial 11 pruned. 


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------


[I 2024-08-07 22:30:35,657] Trial 12 pruned. 


Epoch 0/9
----------


[I 2024-08-07 22:30:42,472] Trial 13 pruned. 


Epoch 0/9
----------


[I 2024-08-07 22:30:48,263] Trial 14 pruned. 


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------


[I 2024-08-07 22:31:18,396] Trial 15 pruned. 


Epoch 0/9
----------


[I 2024-08-07 22:31:25,157] Trial 16 pruned. 


Epoch 0/9
----------
Epoch 1/9
----------
Epoch 2/9
----------
Epoch 3/9
----------
Epoch 4/9
----------


[I 2024-08-07 22:31:58,267] Trial 17 pruned. 


Study statistics: 
  Number of finished trials:  18
  Number of pruned trials:  11
  Number of complete trials:  7
Best trial:
  Value:  0.5208333333333334
  Params: 
    lr: 0.0020022613730624438
    beta1: 0.046914599109421654
    beta2: 0.019381017240822685
    beta3: 0.010655285691959324


In [None]:
#ADAM: 
# Params: 
#     lr: 0.0012161288828451268
# SGD
# lr: 0.0004399074848905889
#     sgd_momentum: 0.782315314217227

# Params Adan: 
#     lr: 0.0020022613730624438
#     beta1: 0.046914599109421654
#     beta2: 0.019381017240822685
#     beta3: 0.010655285691959324

In [10]:
studies

{'adam_study': 'adam_study.csv',
 'sgd_study': 'sgd_study.csv',
 'adan_study': 'adan_study.csv'}

In [12]:
study.trials_dataframe().to_csv('study_trials.csv')

In [22]:

def recover_study(file_name='study_trials.csv'):
    trials_df = pd.read_csv(file_name)
    study = optuna.create_study()

    # Add trials back to the study
    for _, row in trials_df.iterrows():
        trial_params = {col.replace('params_',''): row[col] for col in trials_df.columns if col.startswith('params_')}
        trial = optuna.trial.create_trial(
            params=trial_params,
            distributions={
            'lr': optuna.distributions.FloatDistribution(1e-5, 1e-1),
            'beta1': optuna.distributions.FloatDistribution(1e-3, 1e-1),
            'beta2': optuna.distributions.FloatDistribution(1e-3, 1e-1),
            'beta3': optuna.distributions.FloatDistribution(1e-3, 1e-1)
            },
            value=row['value'],
            state=optuna.trial.TrialState.COMPLETE
        )
        # lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
        study.add_trial(trial)
    return study



In [23]:
for study_name, study_file in studies.items():
    if  study_name == 'adan_study':
        study = recover_study(study_file)
        break
optuna.visualization.plot_param_importances(study)

[I 2024-08-07 22:48:41,163] A new study created in memory with name: no-name-c820f13d-0507-4c91-a775-02d70e60a53c


In [24]:
optuna.visualization.plot_contour(study, params=["beta3","lr"])

In [None]:
#process - 
#1. show usage
#2. 
#3 taking model from scratch - 
# finetuning - 
#4 small bert - 
#5 load from scratch -

#just if we train from zero we can see the difference
#

#main question - optuna on large data 