In [None]:
import time
import sys
import os
import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision.models.efficientnet import MBConvConfig, FusedMBConvConfig

import optuna

sys.path.append("/jet/home/azhang19/stat 214/stat-214-lab2-group6/code/modeling")
from preprocessing import to_NCHW, pad_to_384x384, standardize_images
from autoencoder import EfficientNetEncoder, EfficientNetDecoder, AutoencoderConfig
from classification import train_and_validate

# Prefer the GPU when one is visible; otherwise everything below runs on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Backend tuning: "high" lets float32 matmuls use faster reduced-precision kernels
# where the hardware supports them, and cudnn.benchmark lets cuDNN auto-tune its
# convolution algorithms.
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

# Mixed-precision switch. NOTE(review): not read anywhere in the cells visible here —
# presumably consumed by the imported training code; confirm or remove.
use_amp = True

In [None]:
# Load and preprocess data
data = np.load("/jet/home/azhang19/stat 214/stat-214-lab2-group6/data/array_data.npz")

# Each array goes through the same pipeline: (images only) convert to NCHW,
# pad to 384x384, wrap as a tensor of the right dtype, then move to `device`.

# Unlabeled set
unlabeled_images = torch.tensor(
    pad_to_384x384(to_NCHW(data["unlabeled_images"])), dtype=torch.float32
).to(device)  # [161, 8, 384, 384]
unlabeled_masks = torch.tensor(
    pad_to_384x384(data["unlabeled_masks"]), dtype=torch.bool
).to(device)  # [161, 384, 384]

# Labeled set
labeled_images = torch.tensor(
    pad_to_384x384(to_NCHW(data["labeled_images"])), dtype=torch.float32
).to(device)  # [3, 8, 384, 384]
labeled_masks = torch.tensor(
    pad_to_384x384(data["labeled_masks"]), dtype=torch.bool
).to(device)  # [3, 384, 384]
labels = torch.tensor(
    pad_to_384x384(data["labels"]), dtype=torch.long
).to(device)  # [3, 384, 384]

# Standardize images: the statistics returned for the unlabeled set are passed
# back in for the labeled set, so both sets share the same normalization.
unlabeled_images, std_channel, mean_channel = standardize_images(unlabeled_images, unlabeled_masks)
labeled_images, _, _ = standardize_images(labeled_images, labeled_masks, std_channel, mean_channel)

In [None]:
# Load model and extract features
ckpt_path = "/jet/home/azhang19/stat 214/stat-214-lab2-group6/code/modeling/ckpt"
saved_epoch = [100, 200, 400, 800, 1600, 3200, 6400, 12800, 15000, 20000, 25000, 30000, 35000, 40000]

# Feature bank, indexed as
#   [block1-1, block2-1, block3-1, flip, rotate, position in saved_epoch, image, channel, H, W].
# It is allocated directly on `device`: at float32 this buffer is roughly 50 GB, and
# creating it on the host first (torch.zeros(...).to(device)) would need a second
# allocation of the same size in CPU memory before the copy.
all_features = torch.zeros((2, 2, 2, 2, 2, len(saved_epoch), 3, 64, 384, 384), device=device)

num_layers_per_block = [1, 2]
augementation = [True, False]
# Sweep every (block1, block2, block3, flip, rotate) combination.
for block1, block2, block3, flip, rotate in itertools.product(*[num_layers_per_block] * 3, *[augementation] * 2):
    config = AutoencoderConfig(
        num_layers_block=[block1, block2, block3],
        augmentation_flip=flip,
        augmentation_rotate=rotate
    )
    # Positional args: expand_ratio, kernel, stride, input_channels, out_channels, num_layers
    encoder_config = [
        FusedMBConvConfig(1, 3, 1, 16, 16, config.num_layers_block[0]),  # 384x384x8 -> 384x384x16
        FusedMBConvConfig(4, 3, 2, 16, 32, config.num_layers_block[1]),  # 384x384x16 -> 192x192x32
        MBConvConfig(4, 3, 2, 32, 64, config.num_layers_block[2]),       # 192x192x32 -> 96x96x64
    ]

    encoder = EfficientNetEncoder(
        inverted_residual_setting=encoder_config,
        dropout=0.1,
        input_channels=8,
        last_channel=64,
    )

    decoder = EfficientNetDecoder()
    # Only inference happens below, so switch to eval mode once; load_state_dict
    # does not change the train/eval flag. `encoder` is the same module object as
    # autoencoder[0], so it always sees the most recently loaded weights.
    autoencoder = nn.Sequential(encoder, decoder).to(device)
    autoencoder.eval()

    # Checkpoints for one configuration live in a folder named after str(config).
    folder_path = os.path.join(ckpt_path, str(config))

    for i, epoch in enumerate(saved_epoch):
        # map_location keeps loading working when the checkpoint was saved from a
        # different device than the one in use now (e.g. saved on CUDA, run on CPU).
        # NOTE: torch.load unpickles the file — only load checkpoints you trust.
        state_dict = torch.load(
            os.path.join(folder_path, f"autoencoder_{epoch}.pth"),
            map_location=device,
        )
        autoencoder.load_state_dict(state_dict)
        with torch.inference_mode():
            features = encoder(labeled_images)
            # Resample the encoder output to 384x384 so it lines up with the padded labels.
            features = nn.functional.interpolate(features, size=384, mode="bicubic", antialias=True)
            all_features[block1-1, block2-1, block3-1, int(flip), int(rotate), i] = features

In [None]:
# Load saved study
import pickle

# NOTE: pickle.load executes arbitrary code embedded in the file — only point
# this at a study file you produced yourself.
study_path = "/jet/home/azhang19/optuna_study_autoencoder.pkl"
with open(study_path, "rb") as study_file:
    study = pickle.load(study_file)

In [None]:
# Get best trial
# Bare expression: the notebook renders the study's best trial as this cell's output.
study.best_trial

In [None]:
# Get best feature
best_params = study.best_params

# Position of the winning configuration inside the feature bank. Block counts are
# 1-based, hence the -1; the flip/rotate flags are cast to 0/1.
# NOTE(review): "autoencoder_epoch" is used directly as a position on the checkpoint
# axis, so it is presumably an index into saved_epoch rather than a raw epoch count —
# verify against the Optuna search space.
best_index = (
    best_params["block1"] - 1,
    best_params["block2"] - 1,
    best_params["block3"] - 1,
    int(best_params["flip"]),
    int(best_params["rotate"]),
    best_params["autoencoder_epoch"],
)
feature = all_features[best_index]

# Re-standardize the selected features over the labeled masks.
feature, _, _ = standardize_images(feature, labeled_masks)

# Optionally append the standardized input images as extra channels.
# (The key spelling must match what is stored in the study.)
if best_params['with_orginal']:
    feature = torch.cat([feature, labeled_images], dim=1)

In [None]:
# Repeat CV with best hyperparameters
num_repeat = 100
record = np.zeros((num_repeat, 2))  # per repeat: [mean F1, mean accuracy] over folds
classifiers = []

# The hyperparameters are identical for every fold and repeat, so resolve them once.
best_params = study.best_params
classifier_kwargs = dict(
    in_channels=feature.shape[1],
    num_layers=best_params["num_layers"],
    kernel_size=best_params["kernel_size"],
    hidden_channels=best_params["hidden_channels"],
    epochs=best_params["epochs"],
    lr=best_params["lr"],
    weight_decay=best_params["weight_decay"],
    optimizer_class=torch.optim.AdamW if best_params['optimizer'] == 'AdamW' else torch.optim.SGD,
    loss_mix_ratio=best_params["loss_mix_ratio"],
    l1=best_params["l1"],
    class_weight=best_params['class_weight'],
    device=device,
    return_classifier=True,
)

# Only images 0 and 1 take part in cross-validation.
train_val_idx = [0, 1]

for repeat in range(num_repeat):
    # Per-fold metrics. They are addressed by image index, which coincides with
    # the fold position because train_val_idx is [0, 1].
    fold_records_f1 = torch.zeros(len(train_val_idx))
    fold_records_acc = torch.zeros(len(train_val_idx))
    classifiers_in_this_repeat = []

    for i in train_val_idx:
        # Hold out image i, train on the remaining one(s).
        val_idx = [i]
        train_idx = [j for j in train_val_idx if j != i]

        classifier, val_f1, val_acc = train_and_validate(
            train_data=feature[train_idx],
            train_labels=labels[train_idx],
            val_data=feature[val_idx],
            val_labels=labels[val_idx],
            **classifier_kwargs,
        )

        fold_records_f1[i] = val_f1
        fold_records_acc[i] = val_acc
        # Keep the model together with the indices it was trained on.
        classifiers_in_this_repeat.append((train_idx, classifier))

    avg_f1 = fold_records_f1.mean().item()
    avg_acc = fold_records_acc.mean().item()
    record[repeat] = [avg_f1, avg_acc]
    classifiers.append(classifiers_in_this_repeat)

# Mean and spread of the per-repeat averages, shown as the cell output.
record.mean(axis=0), record.std(axis=0)

In [None]:
# Train on images 0-1 and evaluate on the held-out image 2.
# Relies on `num_repeat`, `feature`, `labels` and `study` from the cells above.

classifiers = []
record = np.zeros((num_repeat, 2))  # per repeat: [F1, accuracy] on the held-out image

# Define the split here. Previously the tuple stored with each classifier used the
# `train_idx` left over from the cross-validation cell (a single index), which did
# not describe the two images actually trained on — and raised NameError whenever
# that cell had not been run first.
train_idx = [0, 1]
test_idx = [2]

# The split is the same for every repeat, so slice the data once.
train_data = feature[train_idx]
train_labels = labels[train_idx]
val_data = feature[test_idx]
val_labels = labels[test_idx]

# Hyperparameters are fixed across repeats; resolve them once.
best_params = study.best_params
classifier_kwargs = dict(
    in_channels=feature.shape[1],
    num_layers=best_params["num_layers"],
    kernel_size=best_params["kernel_size"],
    hidden_channels=best_params["hidden_channels"],
    epochs=best_params["epochs"],
    lr=best_params["lr"],
    weight_decay=best_params["weight_decay"],
    optimizer_class=torch.optim.AdamW if best_params['optimizer'] == 'AdamW' else torch.optim.SGD,
    loss_mix_ratio=best_params["loss_mix_ratio"],
    l1=best_params["l1"],
    class_weight=best_params['class_weight'],
    device=device,
    return_classifier=True,
)

for repeat in range(num_repeat):
    classifiers_in_this_repeat = []

    # Train and validate, get both F1 and accuracy
    classifier, val_f1, val_acc = train_and_validate(
        train_data=train_data,
        train_labels=train_labels,
        val_data=val_data,
        val_labels=val_labels,
        **classifier_kwargs,
    )
    # Same (train indices, model) layout as the cross-validation cell uses.
    classifiers_in_this_repeat.append((train_idx, classifier))

    record[repeat] = [val_f1.item(), val_acc.item()]

    classifiers.append(classifiers_in_this_repeat)

# Mean and spread of the held-out metrics across repeats, shown as the cell output.
record.mean(axis=0), record.std(axis=0)