# Getting started with our ultimate beginner guide!

## This tutorial will walk you through the basics of using the `usb` lighting package. Let's get started by training a FixMatch model on CIFAR-10!

In [1]:
import sys
sys.path.append('../')

from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Step 1: define configs and create config

In [2]:
config = {
    # 'algorithm': 'fixmatch',
    'algorithm': 'cpmatch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True, 
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 1,  
    'num_train_iter': 5000,  
    'num_eval_iter': 500,  
    'num_log_iter': 50,  
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,


    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 4000,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': '/data/datasets',
    'ulb_samples_per_class': None,

    # algorithm specific configs
    'hard_label': True,
    'uratio': 2,
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 1,
    'world_size': 1,
    'distributed': False,
    "num_workers": 2,
}
config = get_config(config)

## Step 2: create model and specify algorithm

In [3]:
algorithm = get_algorithm(config,  get_net_builder(config.net, from_name=False), tb_log=None, logger=None)

Files already downloaded and verified
lb count: [400, 400, 400, 400, 400, 400, 400, 400, 400, 400]
ulb count: [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
Files already downloaded and verified
unlabeled data number: 50000, labeled data number 2800
Create train and test data loaders
[!] data loader keys: dict_keys(['train_lb', 'cali', 'train_ulb', 'eval'])
_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
Create optimizer and scheduler


In [None]:
len(algorithm.loader_dict['eval'])

5000

In [31]:
len(algorithm.loader_dict['cali'])

75

In [4]:
type(algorithm)

semilearn.algorithms.cpmatch.cpmatch.CpMatch

## Step 3: create dataset

In [13]:
dataset_dict = get_dataset(config, config.algorithm, config.dataset, config.num_labels, config.num_classes, data_dir=config.data_dir, include_lb_to_ulb=config.include_lb_to_ulb)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)

Files already downloaded and verified
lb count: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
ulb count: [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
Files already downloaded and verified


## Step 4: train

In [4]:
trainer = Trainer(config, algorithm)

In [6]:
# trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)
trainer.fit()

Epoch: 0


OutOfMemoryError: CUDA out of memory. Tried to allocate 62.00 MiB. GPU 1 has a total capacty of 79.15 GiB of which 5.69 MiB is free. Process 1142479 has 36.51 GiB memory in use. Process 1143667 has 36.51 GiB memory in use. Including non-PyTorch memory, this process has 6.13 GiB memory in use. Of the allocated memory 4.40 GiB is allocated by PyTorch, and 121.81 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## verison 1

In [11]:
# trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)
trainer.fit()

Epoch: 0
Cal Error:58.33%, min risk:20.00%, alpha:39.17, threshold:0.860
Threshold, 0.8602000000000001
Cal Error:66.67%, min risk:40.00%, alpha:53.33, threshold:0.854
Threshold, 0.8538
Cal Error:66.67%, min risk:25.00%, alpha:45.83, threshold:0.883
Threshold, 0.8828
Cal Error:50.00%, min risk:25.00%, alpha:37.50, threshold:0.911
Threshold, 0.9114000000000001
Cal Error:58.33%, min risk:33.33%, alpha:45.83, threshold:0.895
Threshold, 0.8952
Cal Error:50.00%, min risk:40.00%, alpha:45.00, threshold:0.919
Threshold, 0.9186000000000001
Cal Error:58.33%, min risk:25.00%, alpha:41.67, threshold:0.904
Threshold, 0.9044000000000001
Cal Error:50.00%, min risk:40.00%, alpha:45.00, threshold:0.832
Threshold, 0.8318000000000001
Cal Error:50.00%, min risk:33.33%, alpha:41.67, threshold:0.872
Threshold, 0.8716
Cal Error:58.33%, min risk:0.00%, alpha:29.17, threshold:0.843
Threshold, 0.8430000000000001
Cal Error:50.00%, min risk:20.00%, alpha:35.00, threshold:0.949
Threshold, 0.9492
Cal Error:58.33%, 

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.975 0.002 0.001 0.002 0.    0.    0.    0.    0.005 0.015]
 [0.    0.971 0.    0.    0.    0.    0.    0.    0.001 0.028]
 [0.047 0.    0.79  0.007 0.129 0.02  0.002 0.    0.001 0.004]
 [0.007 0.001 0.038 0.765 0.053 0.117 0.007 0.    0.005 0.007]
 [0.004 0.    0.024 0.009 0.963 0.    0.    0.    0.    0.   ]
 [0.001 0.    0.017 0.055 0.084 0.842 0.    0.    0.    0.001]
 [0.01  0.    0.37  0.028 0.095 0.001 0.495 0.    0.001 0.   ]
 [0.006 0.    0.039 0.013 0.917 0.018 0.001 0.    0.001 0.005]
 [0.706 0.03  0.002 0.    0.    0.    0.    0.    0.173 0.089]
 [0.004 0.017 0.    0.    0.    0.    0.    0.    0.    0.979]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
1000 iteration, USE_EMA: True, train/sup_loss: 0.0092, train/unsup_loss: 0.0345, train/total_loss: 0.0436, train/util_ratio: 0.4062, train/run_time: 0.4738, eval/loss: 1.3761, eval/top-1-acc: 0.6953, eval/balanced_acc: 0.6953, eval/precision: 0.

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.972 0.003 0.001 0.002 0.    0.    0.001 0.    0.006 0.015]
 [0.    0.978 0.    0.    0.    0.    0.    0.    0.    0.022]
 [0.044 0.    0.784 0.007 0.135 0.016 0.01  0.    0.001 0.003]
 [0.007 0.002 0.024 0.811 0.046 0.089 0.011 0.    0.004 0.006]
 [0.003 0.    0.007 0.01  0.976 0.    0.004 0.    0.    0.   ]
 [0.001 0.    0.015 0.063 0.086 0.833 0.001 0.    0.    0.001]
 [0.007 0.    0.104 0.018 0.03  0.    0.841 0.    0.    0.   ]
 [0.008 0.    0.017 0.015 0.941 0.013 0.002 0.    0.002 0.002]
 [0.654 0.035 0.001 0.    0.    0.    0.    0.    0.223 0.087]
 [0.004 0.017 0.    0.    0.    0.    0.    0.    0.    0.979]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
1500 iteration, USE_EMA: True, train/sup_loss: 0.0174, train/unsup_loss: 0.1320, train/total_loss: 0.1494, train/util_ratio: 0.6562, train/run_time: 0.4644, eval/loss: 1.3396, eval/top-1-acc: 0.7397, eval/balanced_acc: 0.7397, eval/precision: 0.

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.971 0.003 0.001 0.002 0.    0.    0.001 0.    0.006 0.016]
 [0.    0.979 0.    0.    0.    0.    0.    0.    0.    0.021]
 [0.043 0.    0.772 0.007 0.147 0.015 0.013 0.    0.    0.003]
 [0.006 0.002 0.019 0.823 0.043 0.081 0.018 0.    0.004 0.004]
 [0.002 0.    0.005 0.015 0.973 0.    0.005 0.    0.    0.   ]
 [0.001 0.    0.011 0.068 0.082 0.834 0.004 0.    0.    0.   ]
 [0.006 0.    0.037 0.011 0.006 0.    0.939 0.    0.001 0.   ]
 [0.007 0.    0.009 0.013 0.956 0.009 0.002 0.    0.002 0.002]
 [0.597 0.035 0.    0.    0.    0.    0.    0.    0.287 0.081]
 [0.003 0.017 0.    0.    0.    0.    0.    0.    0.001 0.979]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
2000 iteration, USE_EMA: True, train/sup_loss: 0.0006, train/unsup_loss: 0.0005, train/total_loss: 0.0010, train/util_ratio: 0.1875, train/run_time: 0.4383, eval/loss: 1.3416, eval/top-1-acc: 0.7557, eval/balanced_acc: 0.7557, eval/precision: 0.

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.965 0.002 0.001 0.002 0.    0.    0.001 0.    0.013 0.016]
 [0.    0.981 0.    0.    0.    0.    0.    0.    0.    0.019]
 [0.04  0.    0.764 0.009 0.147 0.018 0.018 0.    0.001 0.003]
 [0.006 0.002 0.013 0.815 0.041 0.086 0.027 0.    0.006 0.004]
 [0.001 0.    0.004 0.012 0.97  0.    0.013 0.    0.    0.   ]
 [0.001 0.    0.01  0.062 0.073 0.848 0.006 0.    0.    0.   ]
 [0.004 0.    0.013 0.007 0.003 0.    0.972 0.    0.001 0.   ]
 [0.005 0.    0.008 0.013 0.953 0.013 0.004 0.    0.002 0.002]
 [0.444 0.01  0.    0.    0.    0.    0.    0.    0.483 0.063]
 [0.003 0.018 0.    0.    0.    0.    0.    0.    0.001 0.978]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
2500 iteration, USE_EMA: True, train/sup_loss: 0.0063, train/unsup_loss: 0.0246, train/total_loss: 0.0309, train/util_ratio: 0.5312, train/run_time: 0.4373, eval/loss: 1.2861, eval/top-1-acc: 0.7776, eval/balanced_acc: 0.7776, eval/precision: 0.

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.952 0.004 0.002 0.002 0.001 0.    0.001 0.    0.022 0.016]
 [0.    0.979 0.    0.    0.    0.    0.    0.    0.    0.021]
 [0.037 0.    0.763 0.01  0.15  0.016 0.02  0.    0.001 0.003]
 [0.006 0.001 0.012 0.826 0.038 0.077 0.029 0.    0.008 0.003]
 [0.    0.    0.002 0.014 0.969 0.    0.014 0.    0.001 0.   ]
 [0.001 0.    0.01  0.067 0.07  0.847 0.005 0.    0.    0.   ]
 [0.004 0.    0.007 0.003 0.002 0.    0.983 0.    0.001 0.   ]
 [0.004 0.    0.008 0.014 0.956 0.011 0.003 0.    0.003 0.001]
 [0.185 0.009 0.    0.    0.    0.    0.    0.    0.762 0.044]
 [0.003 0.016 0.    0.    0.    0.    0.    0.    0.002 0.979]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
3000 iteration, USE_EMA: True, train/sup_loss: 0.0043, train/unsup_loss: 0.0033, train/total_loss: 0.0076, train/util_ratio: 0.3750, train/run_time: 0.4410, eval/loss: 1.2508, eval/top-1-acc: 0.8060, eval/balanced_acc: 0.8060, eval/precision: 0.

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.939 0.006 0.003 0.002 0.001 0.    0.002 0.    0.031 0.016]
 [0.    0.98  0.    0.    0.    0.    0.    0.    0.001 0.019]
 [0.033 0.    0.768 0.009 0.151 0.013 0.022 0.    0.001 0.003]
 [0.006 0.001 0.012 0.837 0.034 0.072 0.027 0.    0.008 0.003]
 [0.    0.    0.002 0.014 0.969 0.    0.014 0.    0.001 0.   ]
 [0.001 0.    0.014 0.073 0.066 0.842 0.004 0.    0.    0.   ]
 [0.004 0.    0.005 0.002 0.001 0.    0.987 0.    0.001 0.   ]
 [0.004 0.    0.008 0.017 0.955 0.01  0.002 0.    0.003 0.001]
 [0.077 0.006 0.    0.    0.    0.    0.    0.    0.887 0.03 ]
 [0.002 0.016 0.    0.    0.    0.    0.    0.    0.003 0.979]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
3500 iteration, USE_EMA: True, train/sup_loss: 0.0002, train/unsup_loss: 0.1023, train/total_loss: 0.1024, train/util_ratio: 0.3125, train/run_time: 0.4995, eval/loss: 1.2610, eval/top-1-acc: 0.8188, eval/balanced_acc: 0.8188, eval/precision: 0.

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.937 0.006 0.004 0.002 0.001 0.    0.001 0.    0.033 0.016]
 [0.    0.98  0.    0.    0.    0.    0.    0.    0.001 0.019]
 [0.032 0.    0.772 0.009 0.152 0.013 0.018 0.    0.001 0.003]
 [0.006 0.001 0.013 0.849 0.034 0.059 0.032 0.    0.005 0.001]
 [0.    0.    0.002 0.014 0.969 0.    0.014 0.    0.001 0.   ]
 [0.001 0.    0.013 0.084 0.075 0.82  0.007 0.    0.    0.   ]
 [0.004 0.    0.003 0.001 0.001 0.    0.99  0.    0.001 0.   ]
 [0.004 0.    0.008 0.017 0.957 0.008 0.002 0.    0.002 0.002]
 [0.056 0.006 0.    0.    0.    0.    0.    0.    0.917 0.021]
 [0.002 0.02  0.    0.    0.    0.    0.    0.    0.003 0.975]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
4000 iteration, USE_EMA: True, train/sup_loss: 0.0100, train/unsup_loss: 0.0856, train/total_loss: 0.0955, train/util_ratio: 0.5000, train/run_time: 0.5913, eval/loss: 1.2859, eval/top-1-acc: 0.8209, eval/balanced_acc: 0.8209, eval/precision: 0.

  _warn_prf(average, modifier, msg_start, len(result))


Cal Error:33.33%, min risk:14.29%, alpha:23.81, threshold:0.999
Threshold, 0.9994000000000001
Cal Error:50.00%, min risk:25.00%, alpha:37.50, threshold:1.000
Threshold, 0.9996
Cal Error:41.67%, min risk:20.00%, alpha:30.83, threshold:0.999
Threshold, 0.9988
Cal Error:33.33%, min risk:0.00%, alpha:16.67, threshold:0.999
Threshold, 0.9992000000000001
Cal Error:33.33%, min risk:12.50%, alpha:22.92, threshold:1.000
Threshold, 0.9996
Cal Error:33.33%, min risk:12.50%, alpha:22.92, threshold:1.000
Threshold, 0.9996
Cal Error:41.67%, min risk:14.29%, alpha:27.98, threshold:1.000
Threshold, 0.9996
Cal Error:33.33%, min risk:12.50%, alpha:22.92, threshold:1.000
Threshold, 0.9996
Cal Error:33.33%, min risk:20.00%, alpha:26.67, threshold:1.000
Threshold, 0.9996
Cal Error:41.67%, min risk:16.67%, alpha:29.17, threshold:0.999
Threshold, 0.9994000000000001
Cal Error:41.67%, min risk:12.50%, alpha:27.08, threshold:1.000
Threshold, 0.9998
Cal Error:41.67%, min risk:20.00%, alpha:30.83, threshold:0.991

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.932 0.007 0.005 0.002 0.002 0.    0.    0.    0.038 0.014]
 [0.    0.985 0.    0.    0.    0.    0.    0.    0.001 0.014]
 [0.031 0.001 0.765 0.012 0.156 0.011 0.02  0.    0.002 0.002]
 [0.005 0.001 0.009 0.873 0.028 0.047 0.033 0.    0.003 0.001]
 [0.    0.    0.002 0.016 0.966 0.    0.015 0.    0.001 0.   ]
 [0.001 0.    0.01  0.107 0.078 0.788 0.016 0.    0.    0.   ]
 [0.004 0.    0.002 0.001 0.001 0.    0.991 0.    0.001 0.   ]
 [0.004 0.    0.005 0.021 0.956 0.008 0.002 0.    0.002 0.002]
 [0.041 0.005 0.    0.    0.    0.    0.    0.    0.939 0.015]
 [0.002 0.023 0.    0.    0.    0.    0.    0.    0.003 0.972]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
5000 iteration, USE_EMA: True, train/sup_loss: 0.0299, train/unsup_loss: 0.0600, train/total_loss: 0.0898, train/util_ratio: 0.3750, train/run_time: 0.5569, eval/loss: 1.2962, eval/top-1-acc: 0.8211, eval/balanced_acc: 0.8211, eval/precision: 0.

  _warn_prf(average, modifier, msg_start, len(result))
[2023-10-27 08:10:13,005 INFO] confusion matrix
[2023-10-27 08:10:13,006 INFO] [[0.919 0.008 0.006 0.003 0.002 0.    0.    0.    0.047 0.015]
 [0.    0.985 0.    0.    0.    0.    0.    0.    0.001 0.014]
 [0.027 0.001 0.783 0.011 0.145 0.009 0.019 0.    0.003 0.002]
 [0.005 0.001 0.009 0.888 0.024 0.037 0.032 0.    0.003 0.001]
 [0.    0.    0.004 0.019 0.961 0.    0.015 0.    0.001 0.   ]
 [0.001 0.    0.014 0.137 0.08  0.751 0.017 0.    0.    0.   ]
 [0.003 0.    0.005 0.001 0.001 0.    0.989 0.    0.001 0.   ]
 [0.004 0.    0.006 0.025 0.954 0.006 0.002 0.    0.002 0.001]
 [0.024 0.005 0.    0.    0.    0.    0.    0.    0.961 0.01 ]
 [0.002 0.023 0.    0.    0.    0.    0.    0.    0.003 0.972]]
[2023-10-27 08:10:13,008 INFO] evaluation metric
[2023-10-27 08:10:13,008 INFO] acc: 0.8209
[2023-10-27 08:10:13,008 INFO] precision: 0.7859
[2023-10-27 08:10:13,009 INFO] recall: 0.8209
[2023-10-27 08:10:13,009 INFO] f1: 0.7918
[2023-

model saved: ./saved_models/fixmatch/latest_model.pth


In [7]:
algorithm.loader_dict

{'train_lb': <torch.utils.data.dataloader.DataLoader at 0x7f6d59f91090>,
 'train_ulb': <torch.utils.data.dataloader.DataLoader at 0x7f6d5841ce90>,
 'eval': <torch.utils.data.dataloader.DataLoader at 0x7f6d3f700190>}

## Step 5: evaluate

In [10]:
import torch
trainer.algorithm.load_model('./saved_models/fixmatch/model_best.pth')

Model loaded


{'model': OrderedDict([('cls_token',
               tensor([[[ 1.0535e+00,  1.2308e+00,  5.5327e-01, -1.9716e+00, -1.0528e+00,
                         -4.0919e-01,  9.3156e-01, -5.5019e-02,  2.6703e+00,  2.7222e-01,
                          1.9811e+00, -2.9867e-01,  6.8812e-01,  7.2430e-02, -1.8258e-01,
                          1.7247e+00, -4.2164e-01, -1.6423e+00,  2.6290e+00, -1.0211e-01,
                         -2.1442e+00,  1.4471e+00,  2.2006e+00, -1.9177e+00,  1.0978e+00,
                          5.2692e-01, -1.5974e+00,  1.0956e+00, -9.3670e-02, -1.3778e+00,
                         -1.8121e+00,  1.0708e+00,  8.1457e-01, -7.4820e-01, -3.4974e-01,
                         -2.2685e+00,  5.3508e-02, -7.5312e-02,  8.7169e-01,  3.7337e-01,
                         -5.8740e-01,  2.0099e-02,  1.4791e+00,  4.1806e-01, -4.7039e-02,
                          7.1206e-01,  8.2810e-01, -1.0441e+00,  1.5621e+00,  1.0640e+00,
                          3.7726e-01,  9.7856e-02,  1.4352e+00,

# CpMatch

## version 1
lambda = np.linspace(0,1,5001)
gamma =0.5

In [14]:
trainer.evaluate(algorithm.loader_dict["eval"])

  _warn_prf(average, modifier, msg_start, len(result))
[2023-10-27 08:11:35,392 INFO] confusion matrix
[2023-10-27 08:11:35,393 INFO] [[0.919 0.008 0.006 0.003 0.002 0.    0.    0.    0.047 0.015]
 [0.    0.985 0.    0.    0.    0.    0.    0.    0.001 0.014]
 [0.027 0.001 0.783 0.011 0.145 0.009 0.019 0.    0.003 0.002]
 [0.005 0.001 0.009 0.888 0.024 0.037 0.032 0.    0.003 0.001]
 [0.    0.    0.004 0.019 0.961 0.    0.015 0.    0.001 0.   ]
 [0.001 0.    0.014 0.137 0.08  0.751 0.017 0.    0.    0.   ]
 [0.003 0.    0.005 0.001 0.001 0.    0.989 0.    0.001 0.   ]
 [0.004 0.    0.006 0.025 0.954 0.006 0.002 0.    0.002 0.001]
 [0.024 0.005 0.    0.    0.    0.    0.    0.    0.961 0.01 ]
 [0.002 0.023 0.    0.    0.    0.    0.    0.    0.003 0.972]]
[2023-10-27 08:11:35,394 INFO] evaluation metric
[2023-10-27 08:11:35,394 INFO] acc: 0.8209
[2023-10-27 08:11:35,395 INFO] precision: 0.7859
[2023-10-27 08:11:35,395 INFO] recall: 0.8209
[2023-10-27 08:11:35,396 INFO] f1: 0.7918


{'acc': 0.8209,
 'precision': 0.7859351377635341,
 'recall': 0.8209,
 'f1': 0.791805644051878}

# FixMatch

In [11]:
trainer.evaluate(eval_loader)

[2023-10-24 22:21:23,749 INFO] confusion matrix
[2023-10-24 22:21:23,750 INFO] [[0.97  0.001 0.001 0.001 0.    0.    0.001 0.    0.016 0.01 ]
 [0.    0.986 0.    0.    0.    0.    0.    0.    0.    0.014]
 [0.043 0.    0.909 0.004 0.021 0.009 0.01  0.004 0.    0.   ]
 [0.003 0.002 0.005 0.908 0.009 0.049 0.019 0.002 0.002 0.001]
 [0.    0.    0.005 0.006 0.949 0.    0.01  0.029 0.001 0.   ]
 [0.002 0.    0.009 0.049 0.007 0.91  0.002 0.021 0.    0.   ]
 [0.003 0.    0.003 0.002 0.    0.001 0.989 0.    0.001 0.001]
 [0.008 0.    0.008 0.003 0.007 0.025 0.    0.948 0.    0.001]
 [0.013 0.005 0.001 0.    0.    0.    0.    0.    0.978 0.003]
 [0.007 0.023 0.    0.    0.    0.    0.    0.    0.003 0.967]]
[2023-10-24 22:21:23,751 INFO] evaluation metric
[2023-10-24 22:21:23,752 INFO] acc: 0.9514
[2023-10-24 22:21:23,752 INFO] precision: 0.9515
[2023-10-24 22:21:23,753 INFO] recall: 0.9514
[2023-10-24 22:21:23,753 INFO] f1: 0.9513


{'acc': 0.9514,
 'precision': 0.9514999747717987,
 'recall': 0.9514000000000001,
 'f1': 0.951259687624711}

## Step 6: predict

In [7]:
y_pred, y_logits = trainer.predict(eval_loader)