In [1]:
import semilearn
from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer
from semilearn import BasicDataset
import sys
from torchvision import datasets,transforms
import torch
from semilearn.datasets.augmentation import RandAugment
import matplotlib.pyplot as plt
import numpy as np
from semilearn.datasets.utils import split_ssl_data

  from .autonotebook import tqdm as notebook_tqdm


## Step 1: define configs and create config

In [35]:
config = {
    'algorithm': 'fixmatch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': False, 
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 2,  
    'num_train_iter': 5000,  
    'num_eval_iter': 500,  
    'num_log_iter': 50,  
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,


    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 40,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',
    'ulb_samples_per_class': None,

    # algorithm specific configs
    'hard_label': True,
    'uratio': 2,
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    "num_workers": 2,
}
config = get_config(config)

/bin/sh: netstat: command not found


## Step 2: create model and specify algorithm

In [36]:
algorithm = get_algorithm(config,  get_net_builder(config.net, from_name=False), tb_log=None, logger=None)

Files already downloaded and verified
lb count: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
ulb count: [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
Files already downloaded and verified
unlabeled data number: 50000, labeled data number 40
Create train and test data loaders
[!] data loader keys: dict_keys(['train_lb', 'train_ulb', 'eval'])
Create optimizer and scheduler


## Step 3: create dataset

In [None]:
# cifar.py get_cifar()
#lb_data, lb_targets, ulb_data, ulb_targets = split_ssl_data()
dataset_dict = get_dataset(config, config.algorithm, config.dataset, config.num_labels, config.num_classes, data_dir=config.data_dir, include_lb_to_ulb=config.include_lb_to_ulb)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)

### 3.1 config

In [39]:
target_class = ['AG','CC','GR','PC','SP']
train_path = f'/pscratch/sd/z/zhangtao/storm/mpc/key_paper/training'
test_path  = f'/pscratch/sd/z/zhangtao/storm/mpc/key_paper/test'
n = 224
n = 32
crop_ratio = config.crop_ratio
num_classes = 5

transform_eval = transforms.Compose([transforms.Resize((n,n)),transforms.ToTensor()])
transform = transforms.Compose([transforms.Resize((n,n))])
transform_weak = transforms.Compose([
    transforms.Resize(n),
    transforms.RandomCrop(n, padding=int(n * (1 - crop_ratio)), padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

transform_strong = transforms.Compose([
    transforms.Resize(n),
    transforms.RandomCrop(n, padding=int(n * (1 - crop_ratio)), padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    RandAugment(3, 5),
    transforms.ToTensor()
])

### 3.2 train dataloader

In [40]:
train_img = []
train_target = []

train_data = datasets.ImageFolder(f'{train_path}',transform)

for img, target in train_data:
    train_img.append(img)
    train_target.append(target)

train_img = np.array(train_img)
train_target = train_target

lb_data, lb_target, ulb_data, ulb_target = split_ssl_data(config, train_img, train_target, 
                                                          num_classes, config.num_labels, 
                                                          include_lb_to_ulb=config.include_lb_to_ulb)

del train_img, train_target

lb_dset = BasicDataset(config.algorithm, lb_data, lb_target, num_classes, transform_weak, False, transform_strong, False)
ulb_dset = BasicDataset(config.algorithm, ulb_data, ulb_target, num_classes, transform_weak, True, transform_strong, False)

train_lb_loader = get_data_loader(config, lb_dset, config.batch_size)
train_ulb_loader = get_data_loader(config, ulb_dset, int(config.batch_size * config.uratio))

lb_count = [0 for _ in range(num_classes)]
ulb_count = [0 for _ in range(num_classes)]
for c in lb_target:
    lb_count[c] += 1
for c in ulb_target:
    ulb_count[c] += 1
print("lb count: {}".format(lb_count))
print("ulb count: {}".format(ulb_count))

del lb_data, lb_target, ulb_data, ulb_target
del lb_dset, ulb_dset

lb count: [8, 8, 8, 8, 8]
ulb count: [4538, 4521, 4500, 4514, 4626]


### 3.3 test dataloader

In [41]:
test_img = []
test_target = []

test_data = datasets.ImageFolder(f'{test_path}',transform)

for img, target in test_data:
    test_img.append(img)
    test_target.append(target)

test_img = np.array(test_img)
test_target = test_target

eval_dset = BasicDataset(config.algorithm, test_img, test_target, num_classes, transform_eval, False, None, False)
eval_loader = get_data_loader(config, eval_dset, config.eval_batch_size)
del test_img, test_target, eval_dset

## Step 4: train

In [42]:
trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader,eval_loader)

Epoch: 0
50 iteration USE_EMA: True, train/sup_loss: 1.4937, train/unsup_loss: 0.0000, train/total_loss: 1.4937, train/util_ratio: 0.0000, train/run_time: 0.0965, lr: 0.0005, train/prefetch_time: 0.0024 
100 iteration USE_EMA: True, train/sup_loss: 1.0346, train/unsup_loss: 0.0000, train/total_loss: 1.0346, train/util_ratio: 0.0000, train/run_time: 0.0960, lr: 0.0005, train/prefetch_time: 0.0022 
150 iteration USE_EMA: True, train/sup_loss: 0.9116, train/unsup_loss: 0.0000, train/total_loss: 0.9116, train/util_ratio: 0.0000, train/run_time: 0.0962, lr: 0.0005, train/prefetch_time: 0.0024 
200 iteration USE_EMA: True, train/sup_loss: 0.8372, train/unsup_loss: 0.0000, train/total_loss: 0.8372, train/util_ratio: 0.0000, train/run_time: 0.0958, lr: 0.0005, train/prefetch_time: 0.0023 
250 iteration USE_EMA: True, train/sup_loss: 0.6101, train/unsup_loss: 0.2556, train/total_loss: 0.8658, train/util_ratio: 0.0312, train/run_time: 0.0958, lr: 0.0005, train/prefetch_time: 0.0024 
300 iteratio

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.25  0.    0.    0.75  0.   ]
 [0.    0.    0.    1.    0.   ]
 [0.016 0.    0.    0.984 0.   ]
 [0.016 0.    0.    0.984 0.   ]
 [0.    0.    0.    1.    0.   ]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
500 iteration, USE_EMA: True, train/sup_loss: 0.5592, train/unsup_loss: 0.0000, train/total_loss: 0.5592, train/util_ratio: 0.0000, train/run_time: 0.0959, eval/loss: 1.5791, eval/top-1-acc: 0.2468, eval/balanced_acc: 0.2468, eval/precision: 0.2190, eval/recall: 0.2468, eval/F1: 0.1468, lr: 0.0005, train/prefetch_time: 0.0676 BEST_EVAL_ACC: 0.2468, at 500 iters
550 iteration USE_EMA: True, train/sup_loss: 0.8808, train/unsup_loss: 0.0780, train/total_loss: 0.9588, train/util_ratio: 0.0312, train/run_time: 0.0966, lr: 0.0005, train/prefetch_time: 0.0022 
600 iteration USE_EMA: True, train/sup_loss: 0.4766, train/unsup_loss: 0.9231, train/total_loss: 1.3997, train/util_ratio: 0.1562, train/run_time: 0.0

  _warn_prf(average, modifier, msg_start, len(result))


confusion matrix:
[[0.478 0.022 0.    0.498 0.002]
 [0.    0.904 0.    0.076 0.02 ]
 [0.194 0.096 0.    0.622 0.088]
 [0.114 0.25  0.    0.606 0.03 ]
 [0.    0.958 0.    0.002 0.04 ]]
model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
1000 iteration, USE_EMA: True, train/sup_loss: 0.5431, train/unsup_loss: 0.4637, train/total_loss: 1.0068, train/util_ratio: 0.1562, train/run_time: 0.0971, eval/loss: 1.4889, eval/top-1-acc: 0.4056, eval/balanced_acc: 0.4056, eval/precision: 0.3143, eval/recall: 0.4056, eval/F1: 0.3190, lr: 0.0005, train/prefetch_time: 0.0023 BEST_EVAL_ACC: 0.4056, at 1000 iters
1050 iteration USE_EMA: True, train/sup_loss: 0.3498, train/unsup_loss: 0.4810, train/total_loss: 0.8308, train/util_ratio: 0.2188, train/run_time: 0.0976, lr: 0.0005, train/prefetch_time: 0.0022 
1100 iteration USE_EMA: True, train/sup_loss: 0.1982, train/unsup_loss: 0.3986, train/total_loss: 0.5968, train/util_ratio: 0.1875, train/run_time:

[2024-01-12 17:23:51,914 INFO] confusion matrix
[2024-01-12 17:23:51,916 INFO] [[0.742 0.006 0.044 0.204 0.004]
 [0.028 0.594 0.012 0.128 0.238]
 [0.08  0.052 0.718 0.044 0.106]
 [0.418 0.048 0.048 0.428 0.058]
 [0.    0.194 0.    0.03  0.776]]
[2024-01-12 17:23:51,918 INFO] evaluation metric
[2024-01-12 17:23:51,918 INFO] acc: 0.6516
[2024-01-12 17:23:51,919 INFO] precision: 0.6586
[2024-01-12 17:23:51,919 INFO] recall: 0.6516
[2024-01-12 17:23:51,920 INFO] f1: 0.6495


model saved: ./saved_models/fixmatch/latest_model.pth
model saved: ./saved_models/fixmatch/model_best.pth
Epoch: 1
2550 iteration USE_EMA: True, train/sup_loss: 0.3127, train/unsup_loss: 0.6687, train/total_loss: 0.9814, train/util_ratio: 0.5000, train/run_time: 0.0979, lr: 0.0004, train/prefetch_time: 0.0022 
2600 iteration USE_EMA: True, train/sup_loss: 0.2053, train/unsup_loss: 0.3388, train/total_loss: 0.5442, train/util_ratio: 0.3750, train/run_time: 0.0984, lr: 0.0004, train/prefetch_time: 0.0024 
2650 iteration USE_EMA: True, train/sup_loss: 0.1463, train/unsup_loss: 0.6436, train/total_loss: 0.7899, train/util_ratio: 0.5312, train/run_time: 0.0977, lr: 0.0004, train/prefetch_time: 0.0021 
2700 iteration USE_EMA: True, train/sup_loss: 0.0668, train/unsup_loss: 0.4209, train/total_loss: 0.4877, train/util_ratio: 0.3438, train/run_time: 0.0985, lr: 0.0004, train/prefetch_time: 0.0024 
2750 iteration USE_EMA: True, train/sup_loss: 0.3384, train/unsup_loss: 0.5480, train/total_loss:

[2024-01-12 17:31:28,431 INFO] confusion matrix
[2024-01-12 17:31:28,432 INFO] [[0.58  0.004 0.104 0.312 0.   ]
 [0.02  0.594 0.014 0.12  0.252]
 [0.03  0.024 0.78  0.042 0.124]
 [0.33  0.038 0.098 0.466 0.068]
 [0.    0.188 0.    0.016 0.796]]
[2024-01-12 17:31:28,434 INFO] evaluation metric
[2024-01-12 17:31:28,434 INFO] acc: 0.6432
[2024-01-12 17:31:28,434 INFO] precision: 0.6434
[2024-01-12 17:31:28,435 INFO] recall: 0.6432
[2024-01-12 17:31:28,435 INFO] f1: 0.6407
[2024-01-12 17:31:28,867 INFO] Best acc 0.6640 at epoch 0
[2024-01-12 17:31:28,868 INFO] Training finished.


model saved: ./saved_models/fixmatch/latest_model.pth


## Step 5: evaluate

In [43]:
trainer.evaluate(eval_loader)

[2024-01-12 17:32:28,994 INFO] confusion matrix
[2024-01-12 17:32:28,996 INFO] [[0.58  0.004 0.104 0.312 0.   ]
 [0.02  0.594 0.014 0.12  0.252]
 [0.03  0.024 0.78  0.042 0.124]
 [0.33  0.038 0.098 0.466 0.068]
 [0.    0.188 0.    0.016 0.796]]
[2024-01-12 17:32:28,998 INFO] evaluation metric
[2024-01-12 17:32:28,998 INFO] acc: 0.6432
[2024-01-12 17:32:28,999 INFO] precision: 0.6434
[2024-01-12 17:32:28,999 INFO] recall: 0.6432
[2024-01-12 17:32:29,000 INFO] f1: 0.6407


{'acc': 0.6432,
 'precision': 0.6434308155032188,
 'recall': 0.6432,
 'f1': 0.6406907814209447}

## Step 6: predict

In [22]:
y_pred, y_logits = trainer.predict(eval_loader)