# Getting started with our ultimate beginner guide!

## This tutorial will walk you through the basics of using the `usb` lighting package. Let's get started by training a FixMatch model on CIFAR-10!

In [1]:
import sys
sys.path.append('../')

from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Step 1: define configs and create config

In [2]:
config = {
    # 'algorithm': 'fixmatch',
    'algorithm': 'cpmatch',
    'save_name': 'cpmatch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True, 
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 1,  
    'num_train_iter': 5000,  
    'num_eval_iter': 500,  
    'num_log_iter': 50,  
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,


    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 4000,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': '/data/datasets',
    'ulb_samples_per_class': None,

    # algorithm specific configs
    'hard_label': True,
    'uratio': 2,
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    "num_workers": 2,
}
config = get_config(config)

## Step 2: create model and specify algorithm

In [3]:
algorithm = get_algorithm(config,  get_net_builder(config.net, from_name=False), tb_log=None, logger=None)

Files already downloaded and verified
lb count: [400, 400, 400, 400, 400, 400, 400, 400, 400, 400]
ulb count: [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
Files already downloaded and verified
unlabeled data number: 50000, labeled data number 3000
Create train and test data loaders
[!] data loader keys: dict_keys(['train_lb', 'cali', 'train_ulb', 'eval'])
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
Create optimizer and scheduler


In [5]:
len(algorithm.loader_dict['eval']), len(algorithm.loader_dict['train_lb']), len(algorithm.loader_dict['train_ulb'])

(625, 5000, 5000)

In [4]:
len(algorithm.loader_dict['cali'])

63

In [6]:
type(algorithm)

semilearn.algorithms.cpmatch.cpmatch.CpMatch

## Step 3: create dataset

In [13]:
dataset_dict = get_dataset(config, config.algorithm, config.dataset, config.num_labels, config.num_classes, data_dir=config.data_dir, include_lb_to_ulb=config.include_lb_to_ulb)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)

Files already downloaded and verified
lb count: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
ulb count: [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
Files already downloaded and verified


## Step 4: train

In [4]:
trainer = Trainer(config, algorithm)

## verison 1

In [None]:
# trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)
trainer.fit()

## Step 5: evaluate

In [5]:
import torch
trainer.algorithm.load_model('./saved_models/fixmatch/model_best.pth')

Model loaded


{'model': OrderedDict([('cls_token',
               tensor([[[ 1.0535e+00,  1.2308e+00,  5.5327e-01, -1.9716e+00, -1.0528e+00,
                         -4.0919e-01,  9.3155e-01, -5.5010e-02,  2.6703e+00,  2.7223e-01,
                          1.9811e+00, -2.9866e-01,  6.8813e-01,  7.2424e-02, -1.8258e-01,
                          1.7247e+00, -4.2164e-01, -1.6423e+00,  2.6290e+00, -1.0211e-01,
                         -2.1442e+00,  1.4471e+00,  2.2006e+00, -1.9177e+00,  1.0978e+00,
                          5.2692e-01, -1.5974e+00,  1.0956e+00, -9.3674e-02, -1.3778e+00,
                         -1.8121e+00,  1.0708e+00,  8.1456e-01, -7.4819e-01, -3.4974e-01,
                         -2.2685e+00,  5.3506e-02, -7.5320e-02,  8.7170e-01,  3.7337e-01,
                         -5.8739e-01,  2.0100e-02,  1.4791e+00,  4.1806e-01, -4.7033e-02,
                          7.1206e-01,  8.2810e-01, -1.0441e+00,  1.5621e+00,  1.0640e+00,
                          3.7726e-01,  9.7860e-02,  1.4352e+00,

# CpMatch

## CIFAR-10 40 82.09%

In [14]:
trainer.evaluate(algorithm.loader_dict["eval"])

  _warn_prf(average, modifier, msg_start, len(result))
[2023-10-27 08:11:35,392 INFO] confusion matrix
[2023-10-27 08:11:35,393 INFO] [[0.919 0.008 0.006 0.003 0.002 0.    0.    0.    0.047 0.015]
 [0.    0.985 0.    0.    0.    0.    0.    0.    0.001 0.014]
 [0.027 0.001 0.783 0.011 0.145 0.009 0.019 0.    0.003 0.002]
 [0.005 0.001 0.009 0.888 0.024 0.037 0.032 0.    0.003 0.001]
 [0.    0.    0.004 0.019 0.961 0.    0.015 0.    0.001 0.   ]
 [0.001 0.    0.014 0.137 0.08  0.751 0.017 0.    0.    0.   ]
 [0.003 0.    0.005 0.001 0.001 0.    0.989 0.    0.001 0.   ]
 [0.004 0.    0.006 0.025 0.954 0.006 0.002 0.    0.002 0.001]
 [0.024 0.005 0.    0.    0.    0.    0.    0.    0.961 0.01 ]
 [0.002 0.023 0.    0.    0.    0.    0.    0.    0.003 0.972]]
[2023-10-27 08:11:35,394 INFO] evaluation metric
[2023-10-27 08:11:35,394 INFO] acc: 0.8209
[2023-10-27 08:11:35,395 INFO] precision: 0.7859
[2023-10-27 08:11:35,395 INFO] recall: 0.8209
[2023-10-27 08:11:35,396 INFO] f1: 0.7918


{'acc': 0.8209,
 'precision': 0.7859351377635341,
 'recall': 0.8209,
 'f1': 0.791805644051878}

## CIFAR-10 4000 97.22% 61mins

In [None]:
# trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)
trainer.fit()

Epoch: 0
Failed control. Cal Error:90.00%, min risk:89.99%, alpha:90.00, threshold:0.95
Cal Error:90.00%, min risk:86.02%, alpha:88.01, threshold:0.66
Cal Error:90.08%, min risk:84.91%, alpha:87.49, threshold:0.61
Cal Error:90.00%, min risk:89.39%, alpha:89.69, threshold:0.58
Cal Error:90.00%, min risk:89.47%, alpha:89.74, threshold:0.54
Cal Error:90.00%, min risk:80.00%, alpha:85.00, threshold:0.51
Cal Error:90.00%, min risk:80.00%, alpha:85.00, threshold:0.48
Cal Error:90.25%, min risk:90.00%, alpha:90.12, threshold:0.44
Cal Error:90.08%, min risk:82.98%, alpha:86.53, threshold:0.39
Cal Error:89.83%, min risk:85.19%, alpha:87.51, threshold:0.37
Cal Error:91.17%, min risk:76.92%, alpha:84.04, threshold:0.34
Cal Error:90.83%, min risk:70.00%, alpha:80.42, threshold:0.33
Cal Error:91.00%, min risk:89.29%, alpha:90.14, threshold:0.30
Cal Error:90.92%, min risk:86.84%, alpha:88.88, threshold:0.29
Cal Error:90.25%, min risk:72.73%, alpha:81.49, threshold:0.30
Cal Error:89.75%, min risk:75.

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.085 0.    0.    0.859 0.    0.    0.    0.    0.    0.056]
 [0.    0.    0.    0.813 0.    0.    0.    0.    0.    0.187]
 [0.001 0.    0.    0.984 0.    0.    0.    0.    0.    0.015]
 [0.    0.    0.    0.997 0.    0.001 0.    0.    0.    0.002]
 [0.    0.    0.    1.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.998 0.    0.001 0.    0.    0.    0.001]
 [0.    0.    0.    0.996 0.    0.001 0.    0.    0.    0.003]
 [0.001 0.    0.    0.974 0.    0.004 0.    0.017 0.    0.004]
 [0.001 0.    0.    0.938 0.    0.    0.    0.    0.    0.061]
 [0.    0.    0.    0.068 0.    0.    0.    0.    0.    0.932]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
500 iteration, USE_EMA: True, train/sup_loss: 0.1530, train/unsup_loss: 0.3087, train/total_loss: 0.4617, train/util_ratio: 0.9688, train/run_time: 1.1391, eval/loss: 4.0727, eval/top-1-acc: 0.2032, eval/balanced_acc: 0.2032, eval/precision: 0.2

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.958 0.    0.    0.014 0.    0.001 0.    0.    0.    0.027]
 [0.007 0.358 0.    0.045 0.    0.    0.    0.002 0.    0.588]
 [0.08  0.    0.063 0.605 0.048 0.086 0.    0.11  0.    0.008]
 [0.002 0.    0.    0.97  0.    0.026 0.    0.001 0.    0.001]
 [0.001 0.    0.    0.077 0.828 0.009 0.    0.085 0.    0.   ]
 [0.    0.    0.    0.14  0.    0.85  0.    0.01  0.    0.   ]
 [0.01  0.    0.    0.933 0.004 0.05  0.    0.    0.    0.003]
 [0.008 0.    0.    0.032 0.    0.054 0.    0.904 0.    0.002]
 [0.312 0.    0.    0.367 0.    0.005 0.    0.001 0.001 0.314]
 [0.002 0.    0.    0.001 0.    0.    0.    0.    0.    0.997]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
1000 iteration, USE_EMA: True, train/sup_loss: 0.2333, train/unsup_loss: 0.3827, train/total_loss: 0.6159, train/util_ratio: 0.9062, train/run_time: 1.0734, eval/loss: 1.4586, eval/top-1-acc: 0.5929, eval/balanced_acc: 0.5929, eval/precision: 0.

[2023-10-27 23:08:16,399 INFO] confusion matrix
[2023-10-27 23:08:16,400 INFO] [[0.974 0.001 0.    0.    0.    0.    0.    0.    0.019 0.006]
 [0.    0.991 0.    0.    0.    0.    0.    0.    0.    0.009]
 [0.016 0.    0.965 0.005 0.005 0.005 0.004 0.    0.    0.   ]
 [0.001 0.    0.001 0.936 0.003 0.051 0.004 0.001 0.002 0.001]
 [0.    0.    0.003 0.005 0.969 0.004 0.003 0.015 0.001 0.   ]
 [0.    0.    0.001 0.02  0.004 0.969 0.001 0.005 0.    0.   ]
 [0.004 0.    0.001 0.003 0.    0.001 0.99  0.    0.001 0.   ]
 [0.006 0.    0.002 0.002 0.005 0.01  0.    0.975 0.    0.   ]
 [0.007 0.004 0.001 0.    0.    0.    0.    0.    0.983 0.005]
 [0.003 0.026 0.    0.    0.    0.    0.    0.    0.001 0.97 ]]
[2023-10-27 23:08:16,402 INFO] evaluation metric
[2023-10-27 23:08:16,402 INFO] acc: 0.9722
[2023-10-27 23:08:16,402 INFO] precision: 0.9724
[2023-10-27 23:08:16,403 INFO] recall: 0.9722
[2023-10-27 23:08:16,403 INFO] f1: 0.9722
[2023-10-27 23:08:16,711 INFO] Best acc 0.9722 at epoch 0
[20

model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth


## CIFAR-10 4000 97.00% 21mins NEW

In [5]:
# trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)
trainer.fit()

Epoch: 0


50 0
Cal Error:90.00%, min risk:84.30%, alpha:87.15, threshold:0.82
50 49
Cal Error:77.20%, min risk:40.91%, alpha:59.05, threshold:0.34
50 iteration USE_EMA: True, train/sup_loss: 2.5132, train/unsup_loss: 0.1072, train/total_loss: 2.6204, train/util_ratio: 0.0625, train/run_time: 0.9249, lr: 0.0005, train/prefetch_time: 0.8318 
50 99
Cal Error:46.20%, min risk:0.00%, alpha:23.10, threshold:0.27
100 iteration USE_EMA: True, train/sup_loss: 1.6636, train/unsup_loss: 0.4285, train/total_loss: 2.0920, train/util_ratio: 0.2812, train/run_time: 1.1018, lr: 0.0005, train/prefetch_time: 0.9140 
50 149
Cal Error:29.60%, min risk:0.00%, alpha:14.80, threshold:0.89
150 iteration USE_EMA: True, train/sup_loss: 1.0657, train/unsup_loss: 0.0000, train/total_loss: 1.0657, train/util_ratio: 0.0000, train/run_time: 1.5390, lr: 0.0005, train/prefetch_time: 1.3036 
50 199
Cal Error:15.40%, min risk:0.00%, alpha:7.70, threshold:0.96
200 iteration USE_EMA: True, train/sup_loss: 0.7319, train/unsup_loss: 

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.    0.996 0.    0.    0.    0.    0.    0.    0.    0.004]
 [0.    1.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.999 0.    0.    0.    0.    0.    0.    0.    0.001]
 [0.    1.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    1.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.996 0.    0.    0.    0.004 0.    0.    0.    0.   ]
 [0.    1.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.999 0.    0.    0.    0.    0.    0.    0.    0.001]
 [0.    1.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    1.    0.    0.    0.    0.    0.    0.    0.    0.   ]]
model saved: ./saved_models/cpmatch/latest_model.pth
model saved: ./saved_models/cpmatch/model_best.pth
500 iteration, USE_EMA: True, train/sup_loss: 0.1447, train/unsup_loss: 0.2598, train/total_loss: 0.4045, train/util_ratio: 0.9062, train/run_time: 1.4775, eval/loss: 3.4857, eval/top-1-acc: 0.1004, eval/balanced_acc: 0.1004, eval/precision: 0.110

[2023-10-29 03:00:48,791 INFO] confusion matrix
[2023-10-29 03:00:48,793 INFO] [[0.965 0.002 0.001 0.001 0.    0.001 0.    0.    0.023 0.007]
 [0.    0.991 0.    0.    0.    0.    0.    0.    0.001 0.008]
 [0.02  0.    0.958 0.004 0.004 0.006 0.007 0.    0.001 0.   ]
 [0.001 0.002 0.002 0.934 0.003 0.045 0.01  0.    0.002 0.001]
 [0.    0.    0.004 0.002 0.969 0.003 0.006 0.016 0.    0.   ]
 [0.    0.    0.001 0.012 0.004 0.976 0.001 0.006 0.    0.   ]
 [0.003 0.    0.002 0.002 0.    0.    0.992 0.    0.001 0.   ]
 [0.004 0.    0.004 0.002 0.008 0.014 0.    0.968 0.    0.   ]
 [0.007 0.008 0.    0.    0.    0.    0.    0.    0.982 0.003]
 [0.002 0.032 0.    0.    0.    0.    0.    0.    0.001 0.965]]
[2023-10-29 03:00:48,794 INFO] evaluation metric
[2023-10-29 03:00:48,794 INFO] acc: 0.9700
[2023-10-29 03:00:48,795 INFO] precision: 0.9703
[2023-10-29 03:00:48,795 INFO] recall: 0.9700
[2023-10-29 03:00:48,796 INFO] f1: 0.9700
[2023-10-29 03:00:48,982 INFO] Best acc 0.9703 at epoch 0
[20

model saved: ./saved_models/cpmatch/latest_model.pth


In [6]:
trainer.evaluate(algorithm.loader_dict["eval"])

[2023-10-29 03:11:46,408 INFO] confusion matrix
[2023-10-29 03:11:46,409 INFO] [[0.965 0.002 0.001 0.001 0.    0.001 0.    0.    0.023 0.007]
 [0.    0.991 0.    0.    0.    0.    0.    0.    0.001 0.008]
 [0.02  0.    0.958 0.004 0.004 0.006 0.007 0.    0.001 0.   ]
 [0.001 0.002 0.002 0.934 0.003 0.045 0.01  0.    0.002 0.001]
 [0.    0.    0.004 0.002 0.969 0.003 0.006 0.016 0.    0.   ]
 [0.    0.    0.001 0.012 0.004 0.976 0.001 0.006 0.    0.   ]
 [0.003 0.    0.002 0.002 0.    0.    0.992 0.    0.001 0.   ]
 [0.004 0.    0.004 0.002 0.008 0.014 0.    0.968 0.    0.   ]
 [0.007 0.008 0.    0.    0.    0.    0.    0.    0.982 0.003]
 [0.002 0.032 0.    0.    0.    0.    0.    0.    0.001 0.965]]
[2023-10-29 03:11:46,410 INFO] evaluation metric
[2023-10-29 03:11:46,410 INFO] acc: 0.9700
[2023-10-29 03:11:46,411 INFO] precision: 0.9703
[2023-10-29 03:11:46,411 INFO] recall: 0.9700
[2023-10-29 03:11:46,411 INFO] f1: 0.9700


{'acc': 0.97,
 'precision': 0.9703027884251654,
 'recall': 0.97,
 'f1': 0.9699890828732898}

# FixMatch

## CIFAR-10-40 95.14%

In [11]:
trainer.evaluate(eval_loader)

[2023-10-24 22:21:23,749 INFO] confusion matrix
[2023-10-24 22:21:23,750 INFO] [[0.97  0.001 0.001 0.001 0.    0.    0.001 0.    0.016 0.01 ]
 [0.    0.986 0.    0.    0.    0.    0.    0.    0.    0.014]
 [0.043 0.    0.909 0.004 0.021 0.009 0.01  0.004 0.    0.   ]
 [0.003 0.002 0.005 0.908 0.009 0.049 0.019 0.002 0.002 0.001]
 [0.    0.    0.005 0.006 0.949 0.    0.01  0.029 0.001 0.   ]
 [0.002 0.    0.009 0.049 0.007 0.91  0.002 0.021 0.    0.   ]
 [0.003 0.    0.003 0.002 0.    0.001 0.989 0.    0.001 0.001]
 [0.008 0.    0.008 0.003 0.007 0.025 0.    0.948 0.    0.001]
 [0.013 0.005 0.001 0.    0.    0.    0.    0.    0.978 0.003]
 [0.007 0.023 0.    0.    0.    0.    0.    0.    0.003 0.967]]
[2023-10-24 22:21:23,751 INFO] evaluation metric
[2023-10-24 22:21:23,752 INFO] acc: 0.9514
[2023-10-24 22:21:23,752 INFO] precision: 0.9515
[2023-10-24 22:21:23,753 INFO] recall: 0.9514
[2023-10-24 22:21:23,753 INFO] f1: 0.9513


{'acc': 0.9514,
 'precision': 0.9514999747717987,
 'recall': 0.9514000000000001,
 'f1': 0.951259687624711}

## CIFAR-10-4000 97.33% 15mins

In [5]:
trainer.fit()

Epoch: 0
50 iteration USE_EMA: True, train/sup_loss: 2.5638, train/unsup_loss: 0.0000, train/total_loss: 2.5638, train/util_ratio: 0.0000, train/run_time: 0.1378, lr: 0.0005, train/prefetch_time: 0.0026 
100 iteration USE_EMA: True, train/sup_loss: 1.6251, train/unsup_loss: 0.0000, train/total_loss: 1.6251, train/util_ratio: 0.0000, train/run_time: 0.1428, lr: 0.0005, train/prefetch_time: 0.0027 
150 iteration USE_EMA: True, train/sup_loss: 1.0265, train/unsup_loss: 0.0000, train/total_loss: 1.0265, train/util_ratio: 0.0000, train/run_time: 0.1414, lr: 0.0005, train/prefetch_time: 0.0040 
200 iteration USE_EMA: True, train/sup_loss: 0.5550, train/unsup_loss: 0.0000, train/total_loss: 0.5550, train/util_ratio: 0.0000, train/run_time: 0.1364, lr: 0.0005, train/prefetch_time: 0.0043 
250 iteration USE_EMA: True, train/sup_loss: 0.6040, train/unsup_loss: 0.0896, train/total_loss: 0.6936, train/util_ratio: 0.1562, train/run_time: 0.1379, lr: 0.0005, train/prefetch_time: 0.0028 
300 iteratio

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.791 0.    0.    0.    0.209 0.    0.    0.    0.    0.   ]
 [0.64  0.    0.    0.    0.36  0.    0.    0.    0.    0.   ]
 [0.054 0.    0.    0.    0.946 0.    0.    0.    0.    0.   ]
 [0.113 0.    0.    0.    0.887 0.    0.    0.    0.    0.   ]
 [0.006 0.    0.    0.    0.994 0.    0.    0.    0.    0.   ]
 [0.015 0.    0.    0.    0.985 0.    0.    0.    0.    0.   ]
 [0.034 0.    0.    0.    0.966 0.    0.    0.    0.    0.   ]
 [0.011 0.    0.    0.    0.989 0.    0.    0.    0.    0.   ]
 [0.562 0.    0.    0.    0.438 0.    0.    0.    0.    0.   ]
 [0.155 0.    0.    0.    0.845 0.    0.    0.    0.    0.   ]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
500 iteration, USE_EMA: True, train/sup_loss: 0.0884, train/unsup_loss: 0.2164, train/total_loss: 0.3048, train/util_ratio: 0.5625, train/run_time: 0.1463, eval/loss: 4.4787, eval/top-1-acc: 0.1785, eval/balanced_acc: 0.1785, eval/precision: 0.0

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.995 0.    0.    0.    0.005 0.    0.    0.    0.    0.   ]
 [0.097 0.88  0.    0.    0.022 0.    0.    0.    0.    0.001]
 [0.055 0.    0.198 0.002 0.743 0.002 0.    0.    0.    0.   ]
 [0.063 0.001 0.001 0.453 0.478 0.004 0.    0.    0.    0.   ]
 [0.    0.    0.    0.    1.    0.    0.    0.    0.    0.   ]
 [0.007 0.    0.    0.007 0.748 0.238 0.    0.    0.    0.   ]
 [0.028 0.    0.    0.004 0.968 0.    0.    0.    0.    0.   ]
 [0.016 0.    0.    0.    0.982 0.    0.    0.002 0.    0.   ]
 [0.697 0.    0.    0.    0.022 0.    0.    0.    0.281 0.   ]
 [0.514 0.024 0.    0.    0.087 0.    0.    0.    0.    0.375]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
1000 iteration, USE_EMA: True, train/sup_loss: 0.5831, train/unsup_loss: 0.2215, train/total_loss: 0.8046, train/util_ratio: 0.7500, train/run_time: 0.1337, eval/loss: 1.3900, eval/top-1-acc: 0.4422, eval/balanced_acc: 0.4422, eval/precision: 0.

[2023-10-27 23:33:17,131 INFO] confusion matrix
[2023-10-27 23:33:17,132 INFO] [[0.979 0.001 0.002 0.001 0.    0.    0.    0.    0.011 0.006]
 [0.    0.985 0.    0.    0.    0.    0.    0.    0.001 0.014]
 [0.014 0.    0.968 0.006 0.003 0.006 0.003 0.    0.    0.   ]
 [0.001 0.002 0.002 0.954 0.003 0.028 0.006 0.001 0.002 0.001]
 [0.    0.    0.007 0.005 0.964 0.004 0.003 0.017 0.    0.   ]
 [0.    0.    0.001 0.027 0.004 0.961 0.001 0.006 0.    0.   ]
 [0.002 0.    0.001 0.002 0.    0.001 0.993 0.    0.001 0.   ]
 [0.003 0.    0.003 0.003 0.009 0.006 0.    0.974 0.    0.002]
 [0.011 0.003 0.    0.    0.    0.    0.    0.    0.983 0.003]
 [0.004 0.019 0.    0.    0.    0.    0.    0.    0.001 0.976]]
[2023-10-27 23:33:17,134 INFO] evaluation metric
[2023-10-27 23:33:17,134 INFO] acc: 0.9737
[2023-10-27 23:33:17,134 INFO] precision: 0.9737
[2023-10-27 23:33:17,135 INFO] recall: 0.9737
[2023-10-27 23:33:17,135 INFO] f1: 0.9737
[2023-10-27 23:33:17,473 INFO] Best acc 0.9737 at epoch 0
[20

model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth


In [6]:
trainer.evaluate(algorithm.loader_dict["eval"])

[2023-10-27 23:33:37,931 INFO] confusion matrix
[2023-10-27 23:33:37,932 INFO] [[0.979 0.001 0.002 0.001 0.    0.    0.    0.    0.011 0.006]
 [0.    0.985 0.    0.    0.    0.    0.    0.    0.001 0.014]
 [0.014 0.    0.968 0.006 0.003 0.006 0.003 0.    0.    0.   ]
 [0.001 0.002 0.002 0.954 0.003 0.028 0.006 0.001 0.002 0.001]
 [0.    0.    0.007 0.005 0.964 0.004 0.003 0.017 0.    0.   ]
 [0.    0.    0.001 0.027 0.004 0.961 0.001 0.006 0.    0.   ]
 [0.002 0.    0.001 0.002 0.    0.001 0.993 0.    0.001 0.   ]
 [0.003 0.    0.003 0.003 0.009 0.006 0.    0.974 0.    0.002]
 [0.011 0.003 0.    0.    0.    0.    0.    0.    0.983 0.003]
 [0.004 0.019 0.    0.    0.    0.    0.    0.    0.001 0.976]]
[2023-10-27 23:33:37,934 INFO] evaluation metric
[2023-10-27 23:33:37,934 INFO] acc: 0.9737
[2023-10-27 23:33:37,934 INFO] precision: 0.9737
[2023-10-27 23:33:37,935 INFO] recall: 0.9737
[2023-10-27 23:33:37,935 INFO] f1: 0.9737


{'acc': 0.9737,
 'precision': 0.9737387543003269,
 'recall': 0.9737000000000002,
 'f1': 0.9736967021957128}

# Step 6: predict

In [7]:
y_pred, y_logits = trainer.predict(eval_loader)