In [2]:
import os
import sys

# Add the parent directory of the current script and 'src' folder to the path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from src.model_unet import *

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load your saved data
test_data = torch.load("./data/prepared_datasets/train_flowers.pt")
test_labels = torch.load("./data/prepared_datasets/train_flowers_labels.pt")
val_data = torch.load("./data/prepared_datasets/val_flowers.pt")
val_labels = torch.load("./data/prepared_datasets/val_flowers_labels.pt")
train_data = torch.load("./data/prepared_datasets/test_flowers.pt")
train_labels = torch.load("./data/prepared_datasets/test_flowers_labels.pt")

In [18]:
batch_size = 32
image_size = 64
channels = 3
epochs = 20

device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
def prepare_data_loaders(train_data, train_labels, val_data, val_labels, test_data, test_labels, batch_size=batch_size):
    train_dataset = TensorDataset(train_data, train_labels)
    val_dataset = TensorDataset(val_data, val_labels)
    test_dataset = TensorDataset(test_data, test_labels)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    return train_loader, val_loader, test_loader

In [7]:
# Normalize data to [0, 1] if not already done
train_data = (train_data - train_data.min()) / (train_data.max() - train_data.min())
val_data = (val_data - val_data.min()) / (val_data.max() - val_data.min())
test_data = (test_data - test_data.min()) / (test_data.max() - test_data.min())

train_data = train_data * 2 - 1
val_data = val_data * 2 - 1
test_data = test_data * 2 - 1
# Create data loaders
train_loader, val_loader, test_loader = prepare_data_loaders(train_data, train_labels, val_data, val_labels, test_data, test_labels)

In [8]:
from pathlib import Path

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000

In [9]:
from torchmetrics.image.fid import FrechetInceptionDistance

def compute_fid(model, real_images, generated_images):
    fid = FrechetInceptionDistance(feature=64)
    fid.update(real_images, real=True)
    fid.update(generated_images, real=False)
    return fid.compute()

In [10]:
import torch
import torch.nn.functional as F
from torchvision.models import inception_v3

def compute_inception_score(generated_images, splits=10):
    inception_model = inception_v3(pretrained=True, transform_input=False)
    inception_model.eval()
    
    # Get predictions
    with torch.no_grad():
        preds = F.softmax(inception_model(generated_images), dim=1)
    
    # Compute inception score
    scores = []
    for i in range(splits):
        part = preds[i * (preds.shape[0] // splits): (i+1) * (preds.shape[0] // splits), :]
        py = torch.mean(part, dim=0)
        KL = torch.mean(torch.sum(part * (torch.log(part + 1e-12) - torch.log(py.unsqueeze(0) + 1e-12)), dim=1))
        scores.append(torch.exp(KL))
    
    return torch.mean(torch.stack(scores))

In [11]:
from scipy.stats import entropy
import numpy as np

def compute_kid(real_features, generated_features):
    # Note: This requires pre-computed features from Inception network
    def _compute_mmd(K):
        m = K.shape[0]
        diag_idx = np.arange(m)
        
        # MMD2 estimator
        t1 = 1 / (m * (m - 1)) * np.sum(K[~np.eye(m, dtype=bool)])
        t2 = 1 / (m * m) * np.sum(K[diag_idx, diag_idx])
        
        return t1 - 2 * t2
    
    def _polynomial_kernel(X, Y):
        return (1.0 + np.dot(X, Y.T)) ** 3
    
    # Compute kernel matrices
    K_xx = _polynomial_kernel(real_features, real_features)
    K_yy = _polynomial_kernel(generated_features, generated_features)
    K_xy = _polynomial_kernel(real_features, generated_features)
    
    # Compute KID
    return _compute_mmd(K_xx) + _compute_mmd(K_yy) - 2 * _compute_mmd(K_xy)

In [14]:
!pip install optuna

Collecting optuna
  Downloading optuna-4.1.0-py3-none-any.whl.metadata (16 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.14.0-py3-none-any.whl.metadata (7.4 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Collecting sqlalchemy>=1.4.2 (from optuna)
  Downloading SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.7 kB)
Collecting Mako (from alembic>=1.5.0->optuna)
  Downloading Mako-1.3.7-py3-none-any.whl.metadata (2.9 kB)
Collecting greenlet!=0.4.17 (from sqlalchemy>=1.4.2->optuna)
  Downloading greenlet-3.1.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (3.8 kB)
Downloading optuna-4.1.0-py3-none-any.whl (364 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m364.4/364.4 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading alembic-1.14.0-py3-none-any.whl (233 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [19]:
import optuna
import torch
import numpy as np

def objective(trial):
    # Hyperparameters to optimize
    lr = trial.suggest_loguniform('lr', 1e-4, 1e-2)
    optimizer_type = trial.suggest_categorical('optimizer', ['Adam', 'AdamW'])
    dim_mults = trial.suggest_categorical('dim_mults', 
        [(1, 2, 4), (1, 2, 4, 8), (1, 2, 3, 4)])
    
    # Reinitialize model and optimizer with suggested hyperparameters
    model = Unet(
        dim=image_size,
        channels=channels,
        dim_mults=dim_mults
    )
    model.to(device)
    
    if optimizer_type == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    
    # Training loop
    for epoch in range(epochs):
        for batch_idx, (data, _) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            data = data.to(device)
            optimizer.zero_grad()
    
            t = torch.randint(0, timesteps, (batch_size,), device=device).long()
            if data.shape[0] == batch_size:                
                loss = p_losses(model, data, t, loss_type="huber", diffusion_params=diffusion_params)                
                if batch_idx % 100 == 0:
                    print("Loss:", loss.item())                
                loss.backward()
                optimizer.step()
    
    # Compute evaluation metrics
    fid_score = compute_fid(model)
    inception_score = compute_inception_score(model)
    kid_score = compute_kid(model)
    
    # Combine scores (you might want a more sophisticated combination)
    combined_score = (fid_score + (1 - inception_score) + kid_score) / 3
    
    return combined_score

# Create a study object and optimize the objective function
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=10)

# Print best parameters
print('Best trial:')
trial = study.best_trial
print('  Value: ', trial.value)
print('  Params: ')
for key, value in trial.params.items():
    print('    {}: {}'.format(key, value))

[I 2024-12-07 15:06:49,078] A new study created in memory with name: no-name-f631248d-cb60-4b59-963c-80a493151d9d
  lr = trial.suggest_loguniform('lr', 1e-4, 1e-2)
Epoch 1/20:   0%|          | 0/193 [00:00<?, ?it/s]

Loss: 0.46857255697250366


Epoch 1/20:  52%|█████▏    | 101/193 [00:23<00:17,  5.34it/s]

Loss: 3119.763427734375


Epoch 1/20: 100%|██████████| 193/193 [00:40<00:00,  4.78it/s]
Epoch 2/20:   1%|          | 1/193 [00:00<00:34,  5.59it/s]

Loss: 151.794677734375


Epoch 2/20:  52%|█████▏    | 101/193 [00:18<00:17,  5.33it/s]

Loss: 74.5447006225586


Epoch 2/20: 100%|██████████| 193/193 [00:36<00:00,  5.35it/s]
Epoch 3/20:   1%|          | 1/193 [00:00<00:36,  5.24it/s]

Loss: 47.12054443359375


Epoch 3/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.36it/s]

Loss: 34.71089553833008


Epoch 3/20: 100%|██████████| 193/193 [00:36<00:00,  5.30it/s]
Epoch 4/20:   1%|          | 1/193 [00:00<00:36,  5.31it/s]

Loss: 25.24637222290039


Epoch 4/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.22it/s]

Loss: 17.466808319091797


Epoch 4/20: 100%|██████████| 193/193 [00:36<00:00,  5.30it/s]
Epoch 5/20:   1%|          | 1/193 [00:00<00:40,  4.69it/s]

Loss: 17.85028076171875


Epoch 5/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.20it/s]

Loss: 16.093585968017578


Epoch 5/20: 100%|██████████| 193/193 [00:36<00:00,  5.29it/s]
Epoch 6/20:   1%|          | 1/193 [00:00<00:35,  5.43it/s]

Loss: 19.390840530395508


Epoch 6/20:  52%|█████▏    | 101/193 [00:19<00:19,  4.78it/s]

Loss: 21.470792770385742


Epoch 6/20: 100%|██████████| 193/193 [00:36<00:00,  5.27it/s]
Epoch 7/20:   1%|          | 1/193 [00:00<00:38,  4.93it/s]

Loss: 30.71481704711914


Epoch 7/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.23it/s]

Loss: 10.096416473388672


Epoch 7/20: 100%|██████████| 193/193 [00:36<00:00,  5.32it/s]
Epoch 8/20:   1%|          | 1/193 [00:00<00:36,  5.27it/s]

Loss: 10.780482292175293


Epoch 8/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.33it/s]

Loss: 19.589126586914062


Epoch 8/20: 100%|██████████| 193/193 [00:36<00:00,  5.34it/s]
Epoch 9/20:   1%|          | 1/193 [00:00<00:35,  5.34it/s]

Loss: 5.420480251312256


Epoch 9/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.22it/s]

Loss: 11.419341087341309


Epoch 9/20: 100%|██████████| 193/193 [00:36<00:00,  5.33it/s]
Epoch 10/20:   1%|          | 1/193 [00:00<00:35,  5.40it/s]

Loss: 8.13436508178711


Epoch 10/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.19it/s]

Loss: 8.755369186401367


Epoch 10/20: 100%|██████████| 193/193 [00:36<00:00,  5.32it/s]
Epoch 11/20:   1%|          | 1/193 [00:00<00:37,  5.16it/s]

Loss: 332.0542907714844


Epoch 11/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.27it/s]

Loss: 8.821314811706543


Epoch 11/20: 100%|██████████| 193/193 [00:36<00:00,  5.26it/s]
Epoch 12/20:   1%|          | 1/193 [00:00<00:37,  5.08it/s]

Loss: 8.913641929626465


Epoch 12/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.33it/s]

Loss: 8.420626640319824


Epoch 12/20: 100%|██████████| 193/193 [00:36<00:00,  5.29it/s]
Epoch 13/20:   1%|          | 1/193 [00:00<00:36,  5.25it/s]

Loss: 5.2083587646484375


Epoch 13/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.27it/s]

Loss: 2.9962947368621826


Epoch 13/20: 100%|██████████| 193/193 [00:36<00:00,  5.27it/s]
Epoch 14/20:   1%|          | 1/193 [00:00<00:36,  5.32it/s]

Loss: 6.999876976013184


Epoch 14/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.22it/s]

Loss: 5.490109920501709


Epoch 14/20: 100%|██████████| 193/193 [00:36<00:00,  5.28it/s]
Epoch 15/20:   1%|          | 1/193 [00:00<00:36,  5.22it/s]

Loss: 3.02449369430542


Epoch 15/20:  52%|█████▏    | 101/193 [00:18<00:17,  5.39it/s]

Loss: 3.329272747039795


Epoch 15/20: 100%|██████████| 193/193 [00:36<00:00,  5.33it/s]
Epoch 16/20:   1%|          | 1/193 [00:00<00:34,  5.59it/s]

Loss: 4.26408052444458


Epoch 16/20:  52%|█████▏    | 101/193 [00:19<00:16,  5.47it/s]

Loss: 2.2728443145751953


Epoch 16/20: 100%|██████████| 193/193 [00:35<00:00,  5.36it/s]
Epoch 17/20:   1%|          | 1/193 [00:00<00:34,  5.57it/s]

Loss: 2.9546689987182617


Epoch 17/20:  52%|█████▏    | 101/193 [00:18<00:17,  5.35it/s]

Loss: 3.290975570678711


Epoch 17/20: 100%|██████████| 193/193 [00:35<00:00,  5.43it/s]
Epoch 18/20:   1%|          | 1/193 [00:00<00:35,  5.47it/s]

Loss: 3.10042142868042


Epoch 18/20:  52%|█████▏    | 101/193 [00:19<00:17,  5.29it/s]

Loss: 2.9758822917938232


Epoch 18/20: 100%|██████████| 193/193 [00:35<00:00,  5.37it/s]
Epoch 19/20:   1%|          | 1/193 [00:00<00:36,  5.32it/s]

Loss: 2.9402170181274414


Epoch 19/20:  52%|█████▏    | 101/193 [00:18<00:16,  5.47it/s]

Loss: 2.4033048152923584


Epoch 19/20: 100%|██████████| 193/193 [00:35<00:00,  5.47it/s]
Epoch 20/20:   1%|          | 1/193 [00:00<00:37,  5.15it/s]

Loss: 2.472158670425415


Epoch 20/20:  52%|█████▏    | 101/193 [00:18<00:17,  5.21it/s]

Loss: 3.9144415855407715


Epoch 20/20: 100%|██████████| 193/193 [00:35<00:00,  5.47it/s]
[W 2024-12-07 15:18:57,686] Trial 0 failed with parameters: {'lr': 0.005561529974922827, 'optimizer': 'Adam', 'dim_mults': (1, 2, 4, 8)} because of the following error: TypeError("compute_fid() missing 2 required positional arguments: 'real_images' and 'generated_images'").
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_390/2195638399.py", line 40, in objective
    fid_score = compute_fid(model)
TypeError: compute_fid() missing 2 required positional arguments: 'real_images' and 'generated_images'
[W 2024-12-07 15:18:57,687] Trial 0 failed with value None.


TypeError: compute_fid() missing 2 required positional arguments: 'real_images' and 'generated_images'