In [None]:
import os
from pathlib import Path
from joblib import Parallel, delayed
from tqdm import tqdm
import shutil
import itertools
import pytorch_lightning as pl
from utils import PyLModel, find_latest_checkpoint_path

# All experiment paths hang off "<current working directory>/data_select",
# so they depend on the directory the kernel was started in.
data_dir = Path.cwd().joinpath("data_select")
# Passed to PyLModel as its `dataset_dir` argument.
dataset_dir = data_dir.joinpath("dataset")
# Searched for existing checkpoints and used as the Trainer's root directory.
log_dir = data_dir.joinpath("log")

In [None]:
# Device string the model is moved to right before training starts.
device = "cuda"

# Epoch limit handed to the Trainer.
max_epochs = 100

# Learning rate forwarded to PyLModel.
# Alternative value kept for reference: 0.00001
lr = 1e-4

# The remaining values are forwarded verbatim to PyLModel; their exact meaning
# is defined in utils.PyLModel, which is not visible from this notebook.
batch_size = 256
validation_data_ratio = 0.05  # presumably the fraction of data held out for validation — TODO confirm
scheduler_step = 10           # presumably the LR-scheduler step interval — TODO confirm
scheduler_gamma = 0.5         # presumably the LR-scheduler decay factor — TODO confirm

In [None]:
# Build the model, restoring from the most recent checkpoint under
# "<log_dir>/lightning_logs" when the helper finds one.
checkpoint_path = find_latest_checkpoint_path(log_dir / "lightning_logs")

# Both code paths receive exactly the same keyword arguments, so they are
# collected once here to keep the fresh and restored models in sync.
model_kwargs = {
    "dataset_dir": dataset_dir,
    "category": "select",
    "lr": lr,
    "batch_size": batch_size,
    "validation_data_ratio": validation_data_ratio,
    "scheduler_step": scheduler_step,
    "scheduler_gamma": scheduler_gamma,
    "train_transform": "full_augmentation",
}

if checkpoint_path is not None:
    # A checkpoint exists: load it, passing the current settings as keyword
    # arguments alongside the checkpoint path.
    model = PyLModel.load_from_checkpoint(str(checkpoint_path), **model_kwargs)
    print("Load:", checkpoint_path)
else:
    # The helper returned None: start from a freshly constructed model.
    model = PyLModel(**model_kwargs)
    print("No checkpoint found.")

In [None]:
# Checkpointing policy: keep only the single best checkpoint, where "best"
# means the lowest value of the monitored "avg_val_loss" metric.
# NOTE(review): "avg_val_loss" has to be reported by PyLModel — that code is
# not visible here, confirm the metric name matches.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="avg_val_loss",
    mode="min",
    save_top_k=1,
    verbose=True,
)

# Trainer on GPU index 0, writing its logs/checkpoints under `log_dir`.
# NOTE(review): `gpus=` and handing a ModelCheckpoint instance to
# `checkpoint_callback=` look like the API of older pytorch_lightning
# releases — confirm the pinned version before upgrading the library.
trainer = pl.Trainer(
    default_root_dir=log_dir,
    max_epochs=max_epochs,
    gpus=[0],
    checkpoint_callback=checkpoint_callback,
)

In [None]:
# Move the model to the configured device, then start training; the model is
# the only argument given to fit().
model = model.to(device)
trainer.fit(model)