In [None]:
import os
# Expose only GPU 0 to this process. This has to happen before `import torch`
# further down, so that CUDA only ever enumerates this one device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Pin this process to CPU cores 0-12 using the external `taskset` utility.
# NOTE(review): os.system's exit status is discarded, so a missing `taskset`
# binary or an invalid core list fails silently.
desired_cpu_cores = "0-12"
pid = os.getpid()
taskset_command = " ".join(["taskset", "-p", "-c", desired_cpu_cores, str(pid)])
os.system(taskset_command)

import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
import aim

from library.dataset import TrainDataset, TestDataset, ImageDataset
from library.model import VectorQuantizer, VQVAE, EnhancedVQVAE
from library.trainer import AdvancedTrainer
from library.threshold import ThresholdOptimizer
from library.evaluator import Evaluator

def _plot_reconstructions(model, loader, device, n_images=10):
    """Show up to ``n_images`` images from the first batch of ``loader`` above their reconstructions."""
    # NOTE(review): assumes ``loader`` yields (images, _, _) triples and that
    # ``model(x)`` returns (reconstructions, _) -- confirm against TestDataset
    # and the model class before enabling.
    model.eval()
    with torch.no_grad():
        images, _, _ = next(iter(loader))
        n_images = min(n_images, images.shape[0])
        images = images[:n_images].to(device)
        reconstructions, _ = model(images)

    plt.figure(figsize=(16, 4))
    for i in range(n_images):
        for row, batch in enumerate((images, reconstructions)):
            plt.subplot(2, n_images, row * n_images + i + 1)
            # Undo Normalize(mean=0.5, std=0.5) so pixels map back from [-1, 1] to [0, 1].
            plt.imshow(batch[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
            plt.axis('off')
    plt.tight_layout()
    plt.show()


def run_experiment(
    optimizer_class=torch.optim.AdamW,
    optimizer_kwargs=None,
    model_class=EnhancedVQVAE,
    epochs=5,
    fine_tune_epochs=1,
    batch_size=512,
    use_perceptual=True,
    image_size=128,
    data_root="dataset",
    split_seed=None,
    plot_reconstructions=False,
):
    """Train a model, tune its anomaly threshold, and evaluate it on the test set.

    Hyperparameters and metrics are logged to a fresh Aim run. The run is
    closed when this function returns or raises.

    Args:
        optimizer_class: torch optimizer class, instantiated on ``model.parameters()``.
        optimizer_kwargs: keyword arguments for ``optimizer_class``; defaults to
            ``{'lr': 1e-4, 'weight_decay': 1e-5}``.
        model_class: model class, constructed with no arguments.
        epochs: passed as the epoch count to ``AdvancedTrainer.train``.
        fine_tune_epochs: forwarded to ``AdvancedTrainer.train``.
        batch_size: batch size used for every DataLoader.
        use_perceptual: forwarded to ``AdvancedTrainer``.
        image_size: images are resized to ``image_size x image_size``.
        data_root: directory containing ``train/``, ``proliv/`` and
            ``test/imgs`` + ``test/test_annotation.txt``.
        split_seed: if not None, seeds the 80/20 train/validation split so that
            different optimizer runs see the same split.
        plot_reconstructions: if True, plot test images next to their
            reconstructions after evaluation.

    Returns:
        dict with keys ``optimal_threshold``, ``percentile_threshold``,
        ``test_tpr`` and ``test_tnr``.
    """
    if optimizer_kwargs is None:
        optimizer_kwargs = {'lr': 1e-4, 'weight_decay': 1e-5}

    run = aim.Run()
    try:
        # Log hyperparameters
        hparams = {
            "optimizer_class": optimizer_class.__name__,
            **optimizer_kwargs,
            "model_class": model_class.__name__,
            "epochs": epochs,
            "fine_tune_epochs": fine_tune_epochs,
            "batch_size": batch_size,
            "use_perceptual": use_perceptual,
            "image_size": image_size,
        }
        if split_seed is not None:
            hparams["split_seed"] = split_seed
        run["hparams"] = hparams

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        train_dir = os.path.join(data_root, "train")
        proliv_dir = os.path.join(data_root, "proliv")
        test_dir = os.path.join(data_root, "test", "imgs")
        annotation_path = os.path.join(data_root, "test", "test_annotation.txt")

        train_dataset = TrainDataset(train_dir, transform=transform)
        proliv_dataset = ImageDataset(proliv_dir, transform=transform)

        # Split BEFORE building the training loader. normal_val is used for
        # validation and threshold tuning, so the model must never train on it.
        split_kwargs = {}
        if split_seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(split_seed)
        normal_train, normal_val = torch.utils.data.random_split(
            train_dataset, [0.8, 0.2], **split_kwargs
        )

        train_loader = DataLoader(normal_train, batch_size=batch_size, shuffle=True, num_workers=0)
        normal_val_loader = DataLoader(normal_val, batch_size=batch_size, shuffle=False, num_workers=0)
        proliv_loader = DataLoader(proliv_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
        # Validation mixes held-out normal images with the "proliv" images.
        val_loader = DataLoader(
            torch.utils.data.ConcatDataset([normal_val, proliv_dataset]),
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
        )

        model = model_class().to(device)
        optimizer = optimizer_class(model.parameters(), **optimizer_kwargs)

        trainer = AdvancedTrainer(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            device=device,
            val_loader=val_loader,
            use_perceptual=use_perceptual,
            run=run,
        )
        trainer.train(epochs, fine_tune_epochs=fine_tune_epochs)

        # Reload the checkpoint the trainer wrote. map_location keeps this
        # working when the checkpoint was saved on GPU but we run on CPU.
        # NOTE(review): confirm AdvancedTrainer stores its *best* weights in
        # 'final_model.pth' (it logs "Saved new best model" separately).
        checkpoint = torch.load('final_model.pth', map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        threshold_optimizer = ThresholdOptimizer(model, normal_val_loader, proliv_loader, device)
        optimal_threshold = threshold_optimizer.find_optimal_threshold()
        run.track(optimal_threshold, name='optimal_threshold')

        test_dataset = TestDataset(test_dir, annotation_path, transform=transform)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        evaluator = Evaluator(model, device)
        tpr, tnr = evaluator.evaluate(test_loader, optimal_threshold)
        # Threshold derived from errors on the data the model was trained on
        # (presumably the 95th percentile). It is logged and returned for
        # comparison only; the test metrics above use optimal_threshold.
        train_errors = evaluator.compute_errors(
            DataLoader(normal_train, batch_size=batch_size, num_workers=0)
        )
        percentile_threshold = evaluator.determine_threshold(train_errors, 95)

        print(f"Final TPR: {tpr}, Final TNR: {tnr}")
        run.track(percentile_threshold, name='percentile_threshold')
        run.track(tpr, name='test_tpr')
        run.track(tnr, name='test_tnr')

        if plot_reconstructions:
            _plot_reconstructions(model, test_loader, device)

        return {
            "optimal_threshold": optimal_threshold,
            "percentile_threshold": percentile_threshold,
            "test_tpr": tpr,
            "test_tnr": tnr,
        }
    finally:
        # Finalize the Aim run even if training/evaluation raised, so that
        # repeated calls in one kernel don't leave runs open.
        run.close()



pid 1378516's current affinity list: 0-79
pid 1378516's new affinity list: 0-12


#### SGD optimizer without perceptual loss

In [None]:
# SGD with momentum, trained with use_perceptual=False.
sgd_config = dict(
    optimizer_class=torch.optim.SGD,
    optimizer_kwargs={"lr": 0.01, "momentum": 0.9},
    epochs=10,
    batch_size=256,
    use_perceptual=False,
)
results = run_experiment(**sgd_config)



  0%|          | 0/40 [00:00<?, ?it/s]


Saved new best model with val loss 0.0119

Epoch 1/10
Train Loss: 0.0127 | Val Loss: 0.0119


  0%|          | 0/40 [00:00<?, ?it/s]


Saved new best model with val loss 0.0116

Epoch 2/10
Train Loss: 0.0096 | Val Loss: 0.0116


  0%|          | 0/40 [00:00<?, ?it/s]

#### SGD optimizer with perceptual loss

In [None]:
# Same SGD setup as the previous cell, but with use_perceptual=True.
sgd_perceptual_config = dict(
    optimizer_class=torch.optim.SGD,
    optimizer_kwargs={"lr": 0.01, "momentum": 0.9},
    epochs=10,
    batch_size=256,
    use_perceptual=True,
)
results = run_experiment(**sgd_perceptual_config)

#### AdamW optimizer without perceptual loss

In [None]:
# AdamW for 5 epochs, trained with use_perceptual=False.
adamw_config = dict(
    optimizer_class=torch.optim.AdamW,
    optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-5},
    epochs=5,
    batch_size=256,
    use_perceptual=False,
)
results = run_experiment(**adamw_config)




  0%|          | 0/40 [00:00<?, ?it/s]


Saved new best model with val loss 0.0122

Epoch 1/5
Train Loss: 0.0146 | Val Loss: 0.0122


  0%|          | 0/40 [00:00<?, ?it/s]


Epoch 2/5
Train Loss: 0.0312 | Val Loss: 0.0535


  0%|          | 0/40 [00:00<?, ?it/s]


Epoch 3/5
Train Loss: 0.0516 | Val Loss: 0.0561


  0%|          | 0/40 [00:00<?, ?it/s]

#### AdamW optimizer with perceptual loss

In [None]:
# Same AdamW setup as the previous cell, but with use_perceptual=True.
adamw_perceptual_config = dict(
    optimizer_class=torch.optim.AdamW,
    optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-5},
    epochs=5,
    batch_size=256,
    use_perceptual=True,
)
results = run_experiment(**adamw_perceptual_config)
