# TUTORIAL ON HOW TO USE H-CAT

# IMPORTS

In [None]:
from src.trainer import PyTorchTrainer
from src.dataloader import MultiFormatDataLoader
from src.models import *
from src.evaluator import Evaluator

# Define experimental parameters

In [None]:
# Experimental configuration for this tutorial run.
#
# Supported values for `hardness`:
# - "uniform": Uniform mislabeling
# - "asymmetric": Asymmetric mislabeling
# - "adjacent": Adjacent mislabeling
# - "instance": Instance-specific mislabeling
# - "ood_covariate": Near-OOD Covariate Shift
# - "domain_shift": Specific type of Near-OOD
# - "far_ood": Far-OOD shift (out-of-support)
# - "zoom_shift": Zoom shift - type of Atypical for images
# - "crop_shift": Crop shift - type of Atypical for images

hardness = "uniform"
# Passed to both the dataloader and the evaluator; presumably the proportion
# of samples to perturb -- confirm against MultiFormatDataLoader.
p = 0.1
dataset = "mnist"
model_name = "LeNet"
epochs = 10
# NOTE(review): `seed` is not read by any cell visible in this tutorial;
# confirm whether it is meant to seed the RNGs before perturbing/training.
seed = 0

# `rule_matrix` is only needed for instance-specific mislabeling and is
# defined by prior or domain knowledge. Judging by the per-entry comments,
# each key is an original class and its list holds the candidate classes it
# may be relabeled to. It stays None for every other hardness type.
rule_matrix = None
if hardness == "instance":
    if dataset == "mnist":
        rule_matrix = {
            1: [7],
            2: [7],
            3: [8],
            4: [4],
            5: [6],
            6: [5],
            7: [1, 2],
            8: [3],
            9: [7],
            0: [0],
        }
    elif dataset == "cifar":
        # Comments describe the mapping exactly as written below
        # (standard CIFAR-10 class order).
        rule_matrix = {
            0: [2],     # airplane -> bird
            1: [9],     # automobile -> truck
            2: [9],     # bird -> truck
            3: [5],     # cat -> dog
            4: [5, 7],  # deer -> dog or horse
            5: [3, 4],  # dog -> cat or deer
            6: [6],     # frog (unchanged)
            7: [5],     # horse -> dog
            8: [7],     # ship -> horse
            9: [9],     # truck (unchanged)
        }
    else:
        # Fail here with a clear message instead of a NameError later when
        # `rule_matrix` is handed to the dataloader.
        raise ValueError(
            f"No instance-specific rule_matrix is defined for dataset {dataset!r}; "
            "expected 'mnist' or 'cifar'."
        )



# Define the hardness characterization methods (HCMs) to evaluate --- if unspecified, all of them are evaluated

In [None]:
# Names of the characterization methods passed to the trainer, in the order
# they are listed here.
characterization_methods = [
    "aum", "data_uncert",  # "data_uncert" is for both Data-IQ and Data-Maps
    "el2n", "grand", "cleanlab",
    "forgetting", "vog", "prototypicality",
    "allsh", "loss", "conf_agree", "detector",
]

# Load datasets

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Build the train/test datasets for the benchmark selected in `dataset`.
# `transform` is the pipeline applied to the training split; the test split
# uses a deterministic pipeline so evaluation data is never randomly augmented.
if dataset == 'cifar':
    # Per-channel normalization shared by the train and test pipelines.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    # Training pipeline: random flip/crop augmentation, then tensor + normalize.
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    # Test pipeline: no random augmentation.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Load the CIFAR-10 dataset (downloaded to ./data if not already present)
    train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)

elif dataset == 'mnist':
    # MNIST uses the same deterministic pipeline for both splits.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    test_transform = transform

    # Load the MNIST dataset (downloaded to ./data if not already present)
    train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=test_transform)

else:
    # Fail fast rather than hitting a NameError on `train_dataset` below.
    raise ValueError(f"Unsupported dataset {dataset!r}; expected 'cifar' or 'mnist'.")


total_samples = len(train_dataset)
# Both supported benchmarks (CIFAR-10 and MNIST) have 10 classes.
num_classes = 10

# Set device to use: GPU when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# STEP 1: Dataloader module

In [None]:
# MultiFormatDataLoader allows importing data in multiple formats; here it
# wraps the torch training dataset and applies the perturbation selected by
# `hardness`, `p` and `rule_matrix`.
_loader_kwargs = dict(
    data=train_dataset,
    target_column=None,
    data_type='torch_dataset',
    data_modality='image',
    batch_size=64,
    shuffle=True,
    num_workers=0,
    transform=None,
    image_transform=None,
    perturbation_method=hardness,
    p=p,
    rule_matrix=rule_matrix,
)
dataloader_class = MultiFormatDataLoader(**_loader_kwargs)

# Two loaders come back; going by their names, one shuffled and one unshuffled.
dataloader, dataloader_unshuffled = dataloader_class.get_dataloader()

# Presumably the ids of the samples that were perturbed -- confirm against
# MultiFormatDataLoader.get_flag_ids.
flag_ids = dataloader_class.get_flag_ids()

# STEP 2: TRAINER module

In [None]:
# Instantiate the neural network and optimizer.
# Single source of truth for the learning rate, which is given to both the
# optimizer and the trainer below.
learning_rate = 0.001

# Map each supported (dataset, model_name) pair to a constructor. Lambdas keep
# construction lazy so only the selected model class is ever touched.
model_factories = {
    ('cifar', 'LeNet'): lambda: LeNet(num_classes=num_classes),
    ('cifar', 'ResNet'): lambda: ResNet18(),
    ('mnist', 'LeNet'): lambda: LeNetMNIST(num_classes=num_classes),
    ('mnist', 'ResNet'): lambda: ResNet18MNIST(),
}

try:
    build_model = model_factories[(dataset, model_name)]
except KeyError:
    # Without this, an unsupported pair would leave `model` undefined and
    # fail with a NameError when the optimizer is created.
    raise ValueError(
        f"Unsupported dataset/model combination: ({dataset!r}, {model_name!r}); "
        f"supported pairs are {sorted(model_factories)}."
    ) from None

model = build_model().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# Instantiate the PyTorchTrainer class
trainer = PyTorchTrainer(model=model,
                         criterion=criterion,
                         optimizer=optimizer,
                         lr=learning_rate,
                         epochs=epochs,
                         total_samples=total_samples,
                         num_classes=num_classes,
                         characterization_methods=characterization_methods,
                         device=device)

# Train the model; the trainer receives both loaders returned by the
# dataloader module.
trainer.fit(dataloader, dataloader_unshuffled)

# Collect the per-method results produced during training; this is what the
# evaluator consumes in the next step.
hardness_dict = trainer.get_hardness_methods()

# STEP 3: Evaluator module

In [None]:
# Build the evaluator from the trainer's hardness results, the flagged sample
# ids from the dataloader, and the perturbation parameter `p`.
# Named `evaluator` (not `eval`) so the builtin `eval` is not shadowed for the
# rest of the session.
evaluator = Evaluator(hardness_dict=hardness_dict, flag_ids=flag_ids, p=p)

# compute_results returns two objects, unpacked as the evaluation summary and
# the raw scores.
eval_dict, raw_scores_dict = evaluator.compute_results()


# Show results

In [None]:
# Bare expression as the last statement of the cell: in a notebook this
# displays `eval_dict` as the cell output.
eval_dict