In [None]:
import numpy as np
import timm
import tlc
import torch
from cleanlab.classification import CleanLearning
from skorch import NeuralNetClassifier
from torch import nn


In [None]:
# Load the 3LC table that the rest of the notebook analyses. The positional
# arguments are presumably the table, dataset and project names — TODO confirm
# against the `tlc.Table.from_names` signature.
table = tlc.Table.from_names(
    "val",
    "chesspieces-val",
    "chessvision-classification",
)

def transform_image(img):
    """Turn one image into a float32 array of shape (1, *img_shape).

    Pixel values are divided by 255 (this assumes 8-bit input, which maps to
    [0, 1]). A leading axis of length 1 is prepended, so a 2-D grayscale image
    becomes (1, H, W).
    """
    scaled = np.asarray(img, dtype=np.float32) / 255.0
    return scaled[np.newaxis, ...]

# Read the table in a single pass. The original iterated it twice (once for
# images, once for labels), which re-reads every row. Column 0 is treated as
# the image and column 1 as the integer class label, as in the original code.
image_list, label_list = [], []
for row in table:
    image_list.append(transform_image(row[0]))
    label_list.append(row[1])

images = np.array(image_list)
labels = np.array(label_list, dtype=np.int64)
# Release the per-row copies; `images` now holds the stacked data.
del image_list, label_list

# The model below is built with in_chans=1, so fail fast here with a clear
# message instead of deep inside the training loop.
assert images.ndim == 4 and images.shape[1] == 1, (
    f"expected (N, 1, H, W) single-channel images, got shape {images.shape}"
)
assert len(images) == len(labels)


In [None]:
# Seed the RNGs so that weight initialisation and cleanlab's cross-validation
# folds are reproducible under Restart & Run All. GPU (cuDNN) kernels may
# still add some nondeterminism.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ResNet-18 adapted to single-channel input. The 13 output classes are
# presumably 12 piece types plus an empty square — TODO confirm this matches
# the table's "label" value map.
model = timm.create_model("resnet18", num_classes=13, in_chans=1)
skorch_model = NeuralNetClassifier(
    model,
    max_epochs=100,
    criterion=nn.CrossEntropyLoss(),
    lr=0.1,
    batch_size=128,
    optimizer__weight_decay=0.0001,
    # No internal validation split: cleanlab does its own cross-validation.
    train_split=None,
    device=device,
)

# CleanLearning fits the classifier with cross-validation and uses the
# out-of-fold predicted probabilities to flag likely label errors.
cleanlearning = CleanLearning(clf=skorch_model, seed=SEED)
label_issues_info = cleanlearning.find_label_issues(images, labels)


In [None]:
# Show only the samples cleanlab flagged, most suspicious (lowest label quality) first.
label_issues_info[label_issues_info["is_label_issue"]].sort_values("label_quality")

In [None]:
value_map = table.get_value_map("label")
labels = [v["internal_name"] for v in value_map.values()]

In [None]:
run = tlc.init(
    "chessvision-classification",
    run_name="cleanlab-testing",
)
run.add_metrics_data(
    label_issues_info.to_dict(orient="list"),
    input_table_url=table.url,
    override_column_schemas={
        "given_label": tlc.CategoricalLabel("label", labels),
        "predicted_label": tlc.CategoricalLabel("label", labels),
    },
)
run.set_status_completed()
