# In [None]:


import os
import torch
import torch.optim as optim
import time
import numpy as np

from network_pathmnist import *
from tools_pathmnist import *

# Module-level side effect: select the 'file_system' strategy for sharing
# tensors between processes.
# NOTE(review): presumably chosen so DataLoader workers do not exhaust file
# descriptors — confirm against how get_data() builds its loaders.
torch.multiprocessing.set_sharing_strategy('file_system')



def get_config():
    """Assemble the DNNH experiment settings.

    The hard-coded settings are handed to ``config_dataset`` (imported from
    the project's tools module) and whatever it returns is returned here.
    """
    # Previously tried alternative: optim.SGD with lr 0.05 and the same
    # weight decay.
    rmsprop_params = {"lr": 1e-5, "weight_decay": 10 ** -5}
    optimizer_settings = {"type": optim.RMSprop, "optim_params": rmsprop_params}

    # Use torch.device("cpu") here when no second GPU is available.
    target_device = torch.device("cuda:1")

    settings = {
        "alpha": 0.1,
        "optimizer": optimizer_settings,
        "info": "[DNNH]",
        "resize_size": 256,
        "crop_size": 224,
        "batch_size": 64,
        # ResNet is the other backbone that has been used here.
        "net": AlexNet,
        # Earlier runs used "cifar10" and "cifar10-1".
        "dataset": "pathmnist",
        "epoch": 250,
        # Validation runs every `test_map` epochs (see train_val).
        "test_map": 15,
        "save_path": "save/DNNH",
        "device": target_device,
        # One training run is launched per entry (see the __main__ block).
        "bit_list": [48],
    }
    return config_dataset(settings)


class DNNHLoss(torch.nn.Module):
    """Triplet ranking hinge loss with a fixed margin of 1.

    For every row the loss is
    ``max(0, ||b - b_pos||^2 - ||b - b_neg||^2 + 1)``; the batch mean is
    returned.
    """

    def __init__(self, config, bit):
        # `config` and `bit` are accepted for a uniform constructor signature
        # but are not used by this loss.
        super(DNNHLoss, self).__init__()

    def forward(self, b, b_pos, b_neg):
        """Return the mean triplet hinge loss.

        Args:
            b: anchor codes; reduced over dim 1, so at least 2-D.
            b_pos: codes that should be close to ``b`` (same shape).
            b_neg: codes that should be far from ``b`` (same shape).

        Returns:
            A 0-dim tensor holding the mean loss over the batch.
        """
        # Squared L2 distances, computed directly so no sqrt is involved
        # (the sqrt's gradient is ill-defined when two codes coincide).
        diff_pos = (b - b_pos).pow(2).sum(dim=1)
        diff_neg = (b - b_neg).pow(2).sum(dim=1)

        # Hinge: torch.max(0, tensor) is not a valid overload and raises
        # TypeError, so clamp the margin term at zero instead.
        loss = torch.clamp(diff_pos - diff_neg + 1, min=0)

        return loss.mean()


def _sample_triplet_indices(label):
    """Pick one (anchor, positive, negative) index triplet per usable sample.

    NOTE(review): the label layout is not visible in this file. Labels with
    more than one column are treated as one-hot / multi-hot vectors (two
    samples match when they share an active class); anything else is
    flattened and compared as class ids — confirm against get_data().

    Returns:
        A tuple of three 1-D index tensors, or ``None`` when no sample in the
        batch has both a matching and a non-matching partner.
    """
    if label.dim() > 1 and label.size(1) > 1:
        hot = label.float()
        same = (hot @ hot.t()) > 0
    else:
        flat = label.reshape(-1)
        same = flat.unsqueeze(0) == flat.unsqueeze(1)

    # A sample may be neither its own positive nor its own negative.
    not_self = ~torch.eye(same.size(0), dtype=torch.bool, device=same.device)
    pos_mask = same & not_self
    neg_mask = (~same) & not_self

    usable = pos_mask.any(dim=1) & neg_mask.any(dim=1)
    anchors = torch.nonzero(usable).squeeze(1)
    if anchors.numel() == 0:
        return None

    # Uniformly draw one partner of each kind for every usable anchor.
    positives = torch.multinomial(pos_mask[anchors].float(), 1).squeeze(1)
    negatives = torch.multinomial(neg_mask[anchors].float(), 1).squeeze(1)
    return anchors, positives, negatives


def train_val(config, bit):
    """Train a hashing network with the DNNH triplet loss and validate periodically.

    Args:
        config: experiment settings; reads "device", "net", "optimizer",
            "epoch", "info", "dataset" and "test_map", and writes "num_train".
        bit: value passed to the network and loss constructors.

    Every ``config["test_map"]`` epochs ``validate`` is called and its return
    value becomes the running ``Best_mAP``.
    """
    device = config["device"]
    train_loader, test_loader, dataset_loader, num_train, num_test, num_dataset = get_data(config)
    config["num_train"] = num_train
    net = config["net"](bit).to(device)

    optimizer = config["optimizer"]["type"](net.parameters(), **(config["optimizer"]["optim_params"]))

    criterion = DNNHLoss(config, bit)

    Best_mAP = 0

    for epoch in range(config["epoch"]):

        current_time = time.strftime('%H:%M:%S', time.localtime(time.time()))

        print("%s[%2d/%2d][%s] bit:%d, dataset:%s, training...." % (
            config["info"], epoch + 1, config["epoch"], current_time, bit, config["dataset"]), end="")

        net.train()

        train_loss = 0
        used_batches = 0
        for image, label, _ind in train_loader:
            image = image.to(device)
            label = label.to(device)

            # The criterion expects (anchor, positive, negative) codes, not
            # (codes, labels, indices, config); build triplets from the batch.
            triplet = _sample_triplet_indices(label)
            if triplet is None:
                # No valid triplet can be formed from this batch.
                continue
            anchor_idx, pos_idx, neg_idx = triplet

            optimizer.zero_grad()
            u = net(image)

            loss = criterion(u[anchor_idx], u[pos_idx], u[neg_idx])
            train_loss += loss.item()
            used_batches += 1

            loss.backward()
            optimizer.step()

        # Average over the batches that actually contributed a loss term.
        train_loss = train_loss / max(used_batches, 1)

        print("\b\b\b\b\b\b\b loss:%.3f" % (train_loss))

        if (epoch + 1) % config["test_map"] == 0:
            Best_mAP = validate(config, Best_mAP, test_loader, dataset_loader, net, bit, epoch, num_dataset)


if __name__ == "__main__":
    # Script entry point: build the settings once, echo them, then launch
    # one training run per entry of "bit_list".
    config = get_config()
    print(config)
    for bit in config["bit_list"]:
        # Each run gets its own precision-recall output file.
        pr_curve_file = f"log/alexnet/DNNH_{config['dataset']}_{bit}.json"
        config["pr_curve_path"] = pr_curve_file
        train_val(config, bit)
