In [1]:
import numpy as np
import timm
import tlc
import torch
from cleanlab.classification import CleanLearning
from skorch import NeuralNetClassifier
from torch import nn


In [2]:
# Open the existing 3LC table to check for label issues.
# NOTE(review): the three positional strings are presumably table name, dataset
# name and project name, in that order — confirm against the tlc.Table.from_names docs.
table = tlc.Table.from_names("val", "chesspieces-val", "chessvision-classification")

def transform_image(img):
    """Convert an image to a float32 array scaled by 1/255 with a new leading axis.

    Args:
        img: Any object ``np.array`` can convert to a numeric array.

    Returns:
        A float32 ``np.ndarray`` whose shape is ``(1, *np.array(img).shape)``.
    """
    pixels = np.array(img, dtype=np.float32)
    # Prepend a length-1 axis (presumably the channel axis expected by the
    # single-channel model — confirm the images are 2-D grayscale).
    return pixels[np.newaxis, ...] / 255.

def _rows_to_arrays(rows):
    """Build the image and label arrays from table rows in a single pass.

    Each row is visited once: ``row[0]`` is passed through ``transform_image``
    and ``row[1]`` is collected as the label. Iterating the table a second time
    only to read the labels would fetch every row (and presumably decode every
    image) twice.

    Args:
        rows: Iterable of rows indexable with ``0`` (image) and ``1`` (label).

    Returns:
        Tuple ``(images, labels)`` of ``np.ndarray``; ``labels`` has dtype
        ``np.int64``.
    """
    pixel_arrays = []
    class_ids = []
    for row in rows:
        pixel_arrays.append(transform_image(row[0]))
        class_ids.append(row[1])
    return np.array(pixel_arrays), np.array(class_ids, dtype=np.int64)


images, labels = _rows_to_arrays(table)


In [3]:
# Fixed seed so a fresh "Restart & Run All" reproduces the same cross-validation
# folds and therefore the same flagged labels.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is available; fall back to CPU instead of failing on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ResNet-18 with a single input channel and 13 output classes.
model = timm.create_model("resnet18", num_classes=13, in_chans=1)

# Wrap the torch module in a scikit-learn compatible classifier, which is the
# interface CleanLearning expects. train_split=None disables skorch's internal
# validation split.
skorch_model = NeuralNetClassifier(
    model,
    max_epochs=100,
    criterion=nn.CrossEntropyLoss(),
    lr=0.1,
    batch_size=128,
    optimizer__weight_decay=0.0001,
    train_split=None,
    device=device,
)

# Score every sample's given label and flag likely label errors.
cleanlearning = CleanLearning(clf=skorch_model, seed=SEED)
label_issues_info = cleanlearning.find_label_issues(images, labels)


  epoch    train_loss     dur
-------  ------------  ------
      1        [36m3.3286[0m  3.3816
      2        3.4674  0.2015
      3        [36m3.0173[0m  0.2012
      4        [36m2.6470[0m  0.2029
      5        [36m2.4745[0m  0.1973
      6        [36m2.3293[0m  0.2053
      7        [36m2.2045[0m  0.2067
      8        [36m2.0981[0m  0.2333
      9        [36m2.0064[0m  0.2230
     10        [36m1.9319[0m  0.2352
     11        [36m1.8667[0m  0.2129
     12        [36m1.7869[0m  0.2201
     13        [36m1.7214[0m  0.2156
     14        [36m1.6654[0m  0.2070
     15        [36m1.6082[0m  0.2280
     16        [36m1.5591[0m  0.2106
     17        [36m1.5027[0m  0.2051
     18        [36m1.4457[0m  0.2146
     19        [36m1.3935[0m  0.2116
     20        [36m1.3565[0m  0.2147
     21        [36m1.2936[0m  0.2183
     22        [36m1.2331[0m  0.2134
     23        [36m1.1789[0m  0.2143
     24        [36m1.1272[0m  0.2162
     25      

In [4]:
# Display the per-sample result returned by CleanLearning.find_label_issues.
label_issues_info

Unnamed: 0,is_label_issue,label_quality,given_label,predicted_label
0,True,0.000420,0,3
1,True,0.100578,0,12
2,False,0.897569,0,0
3,False,0.235365,0,3
4,False,0.436249,0,0
...,...,...,...,...
2129,False,0.485460,12,12
2130,False,0.319342,12,12
2131,False,0.881453,12,12
2132,False,0.575315,12,12


In [5]:
# Fetch the value map of the table's "label" column and collect the
# "internal_name" of each entry, in the map's iteration order.
# NOTE(review): this rebinds `labels`, which earlier in the notebook holds the
# integer label array.
value_map = table.get_value_map("label")
labels = [entry["internal_name"] for entry in value_map.values()]

In [8]:
# Resolve the class names from the table in this cell instead of reading the
# notebook-global `labels`: that name is bound to the integer label array when
# the data is loaded and only later rebound to a list of names, so relying on it
# makes this cell depend on which cells happened to run last.
class_names = [
    entry["internal_name"] for entry in table.get_value_map("label").values()
]

# Create a run in the same project as the input table.
run = tlc.init(
    "chessvision-classification",
    run_name="cleanlab-testing",
)

# Upload the cleanlab results as per-sample metrics linked to the input table.
# Both label columns are overridden with a categorical schema built from the
# class names.
run.add_metrics_data(
    label_issues_info.to_dict(orient="list"),
    input_table_url=table.url,
    override_column_schemas={
        "given_label": tlc.CategoricalLabel("label", class_names),
        "predicted_label": tlc.CategoricalLabel("label", class_names),
    },
)
run.set_status_completed()


[90m3lc: [0mCreated new run at C:/Users/gudbrand/AppData/Local/3LC/3LC/projects/chessvision-classification/runs/cleanlab-testing_0001
