In [None]:
import sys

import dc_check.logger as log
# Register stderr as a sink on the dc_check logger with level "INFO", so the
# library's log output shows up in the notebook.
# NOTE(review): presumably "INFO" is a minimum severity threshold — confirm
# against dc_check.logger.
log.add(sink=sys.stderr, level="INFO")


In [None]:
import torch
from torchvision import datasets, transforms
from dc_check.utils.datasets.images.mnist import load_mnist
from dc_check.utils.datasets.images.cifar import load_cifar

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Experiment configuration: the dataset to load and the network architecture
# that the model-construction cell will build.
dataset = "cifar"
model_name = "LeNet"

# Look the loader up by dataset name; any name without a loader is rejected.
_dataset_loaders = {"mnist": load_mnist, "cifar": load_cifar}
if dataset not in _dataset_loaders:
    raise ValueError("Invalid dataset!")
X_train, y_train, X_test, y_test = _dataset_loaders[dataset]()

In [None]:

import dc_check.plugins.core.models.image_nets as im_nets
from dc_check.plugins.core.datahandler import DataHandler
import torch
import torch.nn as nn

# Instantiate the neural network for the selected (dataset, model_name) pair
# and move it to DEVICE.
#
# Every supported combination assigns `model`. Anything else raises right
# here: without the final `else`, an unknown `model_name` would leave `model`
# unassigned, so the optimizer line would either fail with a NameError on a
# fresh kernel or silently reuse a stale `model` from an earlier run.
if dataset == 'cifar' and model_name == 'LeNet':
    model = im_nets.LeNet(num_classes=10).to(DEVICE)
elif dataset == 'cifar' and model_name == 'ResNet':
    model = im_nets.ResNet18().to(DEVICE)
elif dataset == 'mnist' and model_name == 'LeNet':
    model = im_nets.LeNetMNIST(num_classes=10).to(DEVICE)
elif dataset == 'mnist' and model_name == 'ResNet':
    model = im_nets.ResNet18MNIST().to(DEVICE)
else:
    raise ValueError(
        f"Unsupported combination: dataset={dataset!r}, model_name={model_name!r}"
    )


# Loss function plus the Adam optimizer over the model's parameters.
# `learning_rate` is kept as its own variable because it is also handed to the
# plugin when it is constructed.
criterion = nn.CrossEntropyLoss()
learning_rate = 0.01
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Wrap the training split in a DataHandler configured with batch_size=64.
datahandler = DataHandler(X_train, y_train, batch_size=64)

In [None]:
# Import the plugin registry through its absolute dc_check path and show what
# it lists.
# NOTE(review): presumably these are the names accepted by Plugins().get(...)
# — confirm against dc_check.plugins.
from dc_check.plugins import Plugins

Plugins().list()


In [None]:

# Fetch the "allsh" plugin from the registry, handing it the model together
# with its loss and optimizer, followed by the training settings.
hcm = Plugins().get(
    "allsh",
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    num_classes=10,
    epochs=2,
    lr=learning_rate,
    logging_interval=1,
)

# Fit the plugin on the data supplied by the data handler.
hcm.fit(datahandler=datahandler)

In [None]:
# Read the scores exposed by the plugin after fit(), print them, and then call
# the plugin's own plotting helper. plot_scores() stays last so that whatever
# it returns or renders is what the cell displays.
hcm_scores = hcm.scores
print(hcm_scores)
hcm.plot_scores()

In [None]:
import pandas as pd
# Print the plugin's name and its hard_direction() value.
print(hcm.name())
print(hcm.hard_direction())

# Extract five datapoints with the "top_n" method and tabulate them.
# NOTE(review): assumes hardest_5[0] holds the per-sample data with labels at
# position 1 and dataset indices at position 2, and hardest_5[1] the matching
# scores — confirm against extract_datapoints.
hardest_5 = hcm.extract_datapoints(method="top_n", n=5)
hardest_columns = {
    "indices": hardest_5[0][2],
    "labels": hardest_5[0][1],
    "scores": hardest_5[1],
}
hardest_table = pd.DataFrame(data=hardest_columns)
display(hardest_table)