# Semi-Supervised Learning

코드 출처 : https://github.com/perrying/realistic-ssl-evaluation-pytorch

본 튜토리얼은 위의 코드를 참고하여 작성 되었습니다.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import argparse, math, time, json, os

from lib import wrn, transform

  from .autonotebook import tqdm as notebook_tqdm


## 0. Dataset 준비

<img src=https://user-images.githubusercontent.com/35906602/209686848-82a1e67e-33dd-4036-aaf9-26493b442a0b.png width="600">

* 이미지 출처 : https://gruuuuu.github.io/machine-learning/cifar10-cnn/

본 튜토리얼에서는 Cifar-10 데이터셋이 활용되었습니다. Cifar-10 데이터셋은 32x32 픽셀의 60,000개의 컬러 이미지로 구성된 데이터로, 각 이미지는 총 10개의 클래스로 라벨링 되어 있습니다.

In [2]:
from torchvision import datasets
import argparse, os
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument("--seed", "-s", default=1, type=int, help="random seed")
parser.add_argument("--dataset", "-d", default="cifar10", type=str, help="dataset name : [cifar10]") # Cifar10 사용
parser.add_argument("--nlabels", "-n", default=1000, type=int, help="the number of labeled data")
args = parser.parse_args("")

COUNTS = {
    "cifar10": {"train": 50000, "test": 10000, "valid": 5000, "extra": 0},
    "imagenet_32": {
        "train": 1281167,
        "test": 50000,
        "valid": 50050,
        "extra": 0,
    },
}

_DATA_DIR = "./data"

def split_l_u(train_set, n_labels):
    # NOTE: this function assume that train_set is shuffled.
    images = train_set["images"]
    labels = train_set["labels"]
    classes = np.unique(labels)
    n_labels_per_cls = n_labels // len(classes)
    l_images = []
    l_labels = []
    u_images = []
    u_labels = []
    for c in classes:
        cls_mask = (labels == c)
        c_images = images[cls_mask]
        c_labels = labels[cls_mask]
        l_images += [c_images[:n_labels_per_cls]]
        l_labels += [c_labels[:n_labels_per_cls]]
        u_images += [c_images[n_labels_per_cls:]]
        u_labels += [np.zeros_like(c_labels[n_labels_per_cls:]) - 1] # dammy label
    l_train_set = {"images": np.concatenate(l_images, 0), "labels": np.concatenate(l_labels, 0)}
    u_train_set = {"images": np.concatenate(u_images, 0), "labels": np.concatenate(u_labels, 0)}
    return l_train_set, u_train_set

def _load_cifar10():
    splits = {}
    for train in [True, False]:
        tv_data = datasets.CIFAR10(_DATA_DIR, train, download=True)
        data = {}
        data["images"] = tv_data.data
        data["labels"] = np.array(tv_data.targets)
        splits["train" if train else "test"] = data
    return splits.values()

def gcn(images, multiplier=55, eps=1e-10):
    # global contrast normalization
    images = images.astype(np.float)
    images -= images.mean(axis=(1,2,3), keepdims=True)
    per_image_norm = np.sqrt(np.square(images).sum((1,2,3), keepdims=True))
    per_image_norm[per_image_norm < eps] = 1
    return multiplier * images / per_image_norm

def get_zca_normalization_param(images, scale=0.1, eps=1e-10):
    n_data, height, width, channels = images.shape
    images = images.reshape(n_data, height*width*channels)
    image_cov = np.cov(images, rowvar=False)
    U, S, _ = np.linalg.svd(image_cov + scale * np.eye(image_cov.shape[0]))
    zca_decomp = np.dot(U, np.dot(np.diag(1/np.sqrt(S + eps)), U.T))
    mean = images.mean(axis=0)
    return mean, zca_decomp

def zca_normalization(images, mean, decomp):
    n_data, height, width, channels = images.shape
    images = images.reshape(n_data, -1)
    images = np.dot((images - mean), decomp)
    return images.reshape(n_data, height, width, channels)

rng = np.random.RandomState(args.seed)

validation_count = COUNTS[args.dataset]["valid"]

extra_set = None  # In general, there won't be extra data.

train_set, test_set = _load_cifar10()
train_set["images"] = gcn(train_set["images"])
test_set["images"] = gcn(test_set["images"])
mean, zca_decomp = get_zca_normalization_param(train_set["images"])
train_set["images"] = zca_normalization(train_set["images"], mean, zca_decomp)
test_set["images"] = zca_normalization(test_set["images"], mean, zca_decomp)
# N x H x W x C -> N x C x H x W
train_set["images"] = np.transpose(train_set["images"], (0,3,1,2))
test_set["images"] = np.transpose(test_set["images"], (0,3,1,2))

# permute index of training set
indices = rng.permutation(len(train_set["images"]))
train_set["images"] = train_set["images"][indices]
train_set["labels"] = train_set["labels"][indices]

if extra_set is not None:
    extra_indices = rng.permutation(len(extra_set["images"]))
    extra_set["images"] = extra_set["images"][extra_indices]
    extra_set["labels"] = extra_set["labels"][extra_indices]

# split training set into training and validation
train_images = train_set["images"][validation_count:]
train_labels = train_set["labels"][validation_count:]
validation_images = train_set["images"][:validation_count]
validation_labels = train_set["labels"][:validation_count]
validation_set = {"images": validation_images, "labels": validation_labels}
train_set = {"images": train_images, "labels": train_labels}

# split training set into labeled data and unlabeled data
l_train_set, u_train_set = split_l_u(train_set, args.nlabels)

if not os.path.exists(os.path.join(_DATA_DIR, args.dataset)):
    os.mkdir(os.path.join(_DATA_DIR, args.dataset))

np.save(os.path.join(_DATA_DIR, args.dataset, "l_train"), l_train_set)
np.save(os.path.join(_DATA_DIR, args.dataset, "u_train"), u_train_set)
np.save(os.path.join(_DATA_DIR, args.dataset, "val"), validation_set)
np.save(os.path.join(_DATA_DIR, args.dataset, "test"), test_set)
if extra_set is not None:
    np.save(os.path.join(_DATA_DIR, args.dataset, "extra"), extra_set)


Files already downloaded and verified
Files already downloaded and verified


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


## 00. 공통 설정

밑의 파라미터는 모든 학습에서 공유하게 됩니다. 

기존에는 50만 이상의 많은 수의 Iteration을 필요로 하나, 학습 시간의 문제로 본 튜토리얼에서는 10,000번의 Iteration으로 학습하게 됩니다. 따라서 기존에 알려진 성능보다 낮은 성능을 보인다는 점을 주의해주세요.

In [3]:
device = "cuda"

In [4]:
parser = argparse.ArgumentParser()
parser.add_argument("--alg", "-a", default="VAT", type=str, help="ssl algorithm : [supervised, PI, MT, VAT, PL, ICT]")
parser.add_argument("--em", default=0, type=float, help="coefficient of entropy minimization. If you try VAT + EM, set 0.06")
parser.add_argument("--validation", default=1000, type=int, help="validate at this interval (default 25000)")
parser.add_argument("--dataset", "-d", default="cifar10", type=str, help="dataset name : [cifar10]")
parser.add_argument("--root", "-r", default="data", type=str, help="dataset dir")
parser.add_argument("--output", "-o", default="./exp_res", type=str, help="output dir")
args = parser.parse_args("")

In [5]:
shared_config = {
    "iteration" : 10000,
    "warmup" : 4000,
    "lr_decay_iter" : 8000,
    "lr_decay_factor" : 0.2,
    "batch_size" : 100,
}

In [6]:
class RandomSampler(torch.utils.data.Sampler):
    """ sampling without replacement """
    def __init__(self, num_data, num_sample):
        iterations = num_sample // num_data + 1
        self.indices = torch.cat([torch.randperm(num_data) for _ in range(iterations)]).tolist()[:num_sample]

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

### Dataset

딥러닝 모델을 학습시키기 위한 데이터셋을 구성하는 단계입니다.

In [7]:
class CIFAR10:
    def __init__(self, root, split="l_train"):
        self.dataset = np.load(os.path.join(root, "cifar10", split+".npy"), allow_pickle=True).item()

    def __getitem__(self, idx):
        image = self.dataset["images"][idx]
        label = self.dataset["labels"][idx]
        return image, label

    def __len__(self):
        return len(self.dataset["images"])

In [8]:
cifar10_config = {
    "transform" : [True, True, True], # 차례대로 Horizontal flip, Random crop, Gaussian Noise를 의미합니다.
    "dataset" : CIFAR10,
    "num_classes" : 10,
}

In [9]:
dataset_cfg = cifar10_config
transform_fn = transform.transform(*dataset_cfg["transform"])

l_train_dataset = dataset_cfg["dataset"](args.root, "l_train")
u_train_dataset = dataset_cfg["dataset"](args.root, "u_train")
val_dataset = dataset_cfg["dataset"](args.root, "val")
test_dataset = dataset_cfg["dataset"](args.root, "test")

print("labeled data : {}, unlabeled data : {}, training data : {}".format(
    len(l_train_dataset), len(u_train_dataset), len(l_train_dataset)+len(u_train_dataset)))
print("validation data : {}, test data : {}".format(len(val_dataset), len(test_dataset)))


holizontal flip : True, random crop : True, gaussian noise : True
labeled data : 1000, unlabeled data : 44000, training data : 45000
validation data : 5000, test data : 10000


In [10]:
l_loader = DataLoader(
    l_train_dataset, shared_config["batch_size"], drop_last=True,
    sampler=RandomSampler(len(l_train_dataset), shared_config["iteration"] * shared_config["batch_size"])
)

u_loader = DataLoader(
    u_train_dataset, shared_config["batch_size"]//2, drop_last=True,
    sampler=RandomSampler(len(u_train_dataset), shared_config["iteration"] * shared_config["batch_size"]//2)
)

val_loader = DataLoader(val_dataset, 128, shuffle=False, drop_last=False)
test_loader = DataLoader(test_dataset, 128, shuffle=False, drop_last=False)

## 1. VAT

<img src="https://user-images.githubusercontent.com/35906602/209687878-59d1e148-96db-4959-90ae-53096a6e8e71.png" width="700">

Virtual Adversarial Training이란 Adversarial Training을 Semi-Supervised Learning에 접목한 방법론입니다. 라벨이 없는 데이터에 가상의 적대적 방향을 정의하고, 이 방향을 이용해 Adversarial Training을 수행하게 됩니다.

참고 : https://creamnuts.github.io/short%20review/vat/

### Config

In [11]:
args.alg = 'VAT'
vat_config = {
    # virtual adversarial training
    "xi" : 1e-6,
    "eps" : {"cifar10":6},
    "consis_coef" : 0.3,
    "lr" : 3e-3
}
alg_cfg = vat_config

### Model

In [12]:
class VAT(nn.Module):
    def __init__(self, eps=1.0, xi=1e-6, n_iteration=1):
        super().__init__()
        self.eps = eps
        self.xi = xi
        self.n_iteration = n_iteration

    def kld(self, q_logit, p_logit):
        q = q_logit.softmax(1)
        qlogp = (q * self.__logsoftmax(p_logit)).sum(1)
        qlogq = (q * self.__logsoftmax(q_logit)).sum(1)
        return qlogq - qlogp

    def normalize(self, v):
        v = v / (1e-12 + self.__reduce_max(v.abs(), range(1, len(v.shape))))
        v = v / (1e-6 + v.pow(2).sum((1,2,3),keepdim=True)).sqrt()
        return v

    def forward(self, x, y, model, mask):
        model.update_batch_stats(False)
        d = torch.randn_like(x)
        d = self.normalize(d)
        for _ in range(self.n_iteration):
            d.requires_grad = True
            x_hat = x + self.xi * d
            y_hat = model(x_hat)
            kld = self.kld(y.detach(), y_hat).mean()
            d = torch.autograd.grad(kld, d)[0]
            d = self.normalize(d).detach()
        x_hat = x + self.eps * d
        y_hat = model(x_hat)
        # NOTE:
        # Original implimentation of VAT defines KL(P(y|x)||P(x|x+r_adv)) as loss function
        # However, Avital Oliver's implimentation use KL(P(y|x+r_adv)||P(y|x)) as loss function of VAT
        # see issue https://github.com/brain-research/realistic-ssl-evaluation/issues/27
        loss = (self.kld(y_hat, y.detach()) * mask).mean()
        model.update_batch_stats(True)
        return loss

    def __reduce_max(self, v, idx_list):
        for i in idx_list:
            v = v.max(i, keepdim=True)[0]
        return v

    def __logsoftmax(self,x):
        xdev = x - x.max(1, keepdim=True)[0]
        lsm = xdev - xdev.exp().sum(1, keepdim=True).log()
        return lsm

In [13]:
model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
optimizer = optim.Adam(model.parameters(), lr=alg_cfg["lr"])

trainable_paramters = sum([p.data.nelement() for p in model.parameters()])
print("trainable parameters : {}".format(trainable_paramters))

trainable parameters : 1467610


<img src="https://user-images.githubusercontent.com/35906602/209687686-e6183eae-6877-4d58-af50-f76481117eb9.png" width="800">

지도학습 모델은 WideResnet을 사용하게 됩니다. 이 튜토리얼에서는 SSL 방법론의 성능을 비교하기 위해 VAT를 비롯한 5개 SSL 모델 모두에서 해당 모델을 Backbone 모델로 통일하였습니다.

In [14]:
ssl_obj = VAT(alg_cfg["eps"][args.dataset], alg_cfg["xi"], 1)

### Train

In [15]:
###
VAT_cls_loss_list = []
VAT_ssl_loss_list = []
VAT_loss_list = []

VAT_val_acc_list = []
###

iteration = 0
maximum_val_acc = 0
s = time.time()
for l_data, u_data in zip(l_loader, u_loader):
    iteration += 1
    l_input, target = l_data
    l_input, target = l_input.to(device).float(), target.to(device).long()

    u_input, dummy_target = u_data
    u_input, dummy_target = u_input.to(device).float(), dummy_target.to(device).long()

    target = torch.cat([target, dummy_target], 0)
    unlabeled_mask = (target == -1).float()

    inputs = torch.cat([l_input, u_input], 0)
    outputs = model(inputs)

    # ramp up exp(-5(1 - t)^2)
    coef = alg_cfg["consis_coef"] * math.exp(-5 * (1 - min(iteration/shared_config["warmup"], 1))**2)
    ssl_loss = ssl_obj(inputs, outputs.detach(), model, unlabeled_mask) * coef

    # supervised loss
    cls_loss = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean()

    loss = cls_loss + ssl_loss
    
    ###
    VAT_cls_loss_list.append(cls_loss)
    VAT_ssl_loss_list.append(ssl_loss)
    VAT_loss_list.append(loss)
    ###

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # 표기
    if iteration == 1 or (iteration % 100) == 0:
        wasted_time = time.time() - s
        rest = (shared_config["iteration"] - iteration)/100 * wasted_time / 60
        print("iteration [{}/{}] cls loss : {:.6e}, SSL loss : {:.6e}, coef : {:.5e}, time : {:.3f} iter/sec, rest : {:.3f} min, lr : {}".format(
            iteration, shared_config["iteration"], cls_loss.item(), ssl_loss.item(), coef, 100 / wasted_time, rest, optimizer.param_groups[0]["lr"]),
            "\r", end="")
        s = time.time()
        
    # Validation
    if (iteration % args.validation) == 0 or iteration == shared_config["iteration"]:
        with torch.no_grad():
            model.eval()
            print()
            print("### validation ###")
            sum_acc = 0.
            s = time.time()
            for j, data in enumerate(val_loader):
                input, target = data
                input, target = input.to(device).float(), target.to(device).long()

                output = model(input)

                pred_label = output.max(1)[1]
                sum_acc += (pred_label == target).float().sum()
                if ((j+1) % 10) == 0:
                    d_p_s = 10/(time.time()-s)
                    print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                        j+1, len(val_loader), d_p_s, (len(val_loader) - j-1)/d_p_s
                    ), "\r", end="")
                    s = time.time()
            acc = sum_acc/float(len(val_dataset))
            print()
            print("varidation accuracy : {}".format(acc))
            
            ###
            VAT_val_acc_list.append(acc)
            ###
            
            # Test
            if maximum_val_acc < acc:
                print("### test ###")
                maximum_val_acc = acc
                sum_acc = 0.
                s = time.time()
                for j, data in enumerate(test_loader):
                    input, target = data
                    input, target = input.to(device).float(), target.to(device).long()
                    output = model(input)
                    pred_label = output.max(1)[1]
                    sum_acc += (pred_label == target).float().sum()
                    if ((j+1) % 10) == 0:
                        d_p_s = 100/(time.time()-s)
                        print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                            j+1, len(test_loader), d_p_s, (len(test_loader) - j-1)/d_p_s
                        ), "\r", end="")
                        s = time.time()
                print()
                test_acc = sum_acc / float(len(test_dataset))
                print("test accuracy : {}".format(test_acc))
                # torch.save(model.state_dict(), os.path.join(args.output, "best_model.pth"))
        model.train()
        s = time.time()
        
    # lr decay
    if iteration == shared_config["lr_decay_iter"]:
        optimizer.param_groups[0]["lr"] *= shared_config["lr_decay_factor"]    
        
print("test acc : {}".format(test_acc))

iteration [1000/10000] cls loss : 1.036135e-02, SSL loss : 1.539129e-02, coef : 1.80164e-02, time : 13.852 iter/sec, rest : 10.829 min, lr : 0.003 
### validation ###
[40/40] time : 149.1 data/sec, rest : 0.00 sec 
varidation accuracy : 0.5533999800682068
### test ###
[70/79] time : 1537.1 data/sec, rest : 0.01 sec 
test accuracy : 0.5339999794960022
iteration [2000/10000] cls loss : 2.781086e-02, SSL loss : 6.056770e-02, coef : 8.59514e-02, time : 12.757 iter/sec, rest : 10.452 min, lr : 0.003 
### validation ###
[40/40] time : 153.7 data/sec, rest : 0.00 sec 
varidation accuracy : 0.5149999856948853
iteration [3000/10000] cls loss : 1.537405e-02, SSL loss : 2.573695e-01, coef : 2.19485e-01, time : 13.394 iter/sec, rest : 8.710 min, lr : 0.003  
### validation ###
[40/40] time : 146.9 data/sec, rest : 0.00 sec 
varidation accuracy : 0.5491999983787537
iteration [4000/10000] cls loss : 3.041661e-02, SSL loss : 1.678324e-01, coef : 3.00000e-01, time : 13.083 iter/sec, rest : 7.644 min, 

## 2. Mean Teacher

<img src="https://user-images.githubusercontent.com/35906602/209688575-a6512a9a-9f4d-43eb-884e-b105c70e1477.png" width="1000">

* Reference : https://nuguziii.github.io/paper-review/PR-009/
* Tarvainen and Valpora. Mean teachers are better role models: Weighted-averaged consistency targets improve semi-supervised deep learning results. NIPS 2017

같은 구조를 가지는 2개의 모델 student model과 teacher model이 존재하며, Student model은 labeled data를 Input으로 받으며, teacher model은 unlabeled data를 Input으로 받게 됩니다. 

Student model은 지도 학습 기반의 손실 함수 및 teacher model과의 consistency loss로 학습이 되며, teacher model은 student model의 parameter를 지수 이동 평균하여 update하기 때문에 역전파가 진행되지 않음

### Config

In [16]:
args.alg = 'MT'
mt_config = {
    # mean teacher
    "ema_factor" : 0.95,
    "lr" : 4e-4,
    "consis_coef" : 8,
}
alg_cfg = mt_config

### Model

In [17]:
class MT(nn.Module):
    def __init__(self, model, ema_factor):
        super().__init__()
        self.model = model
        self.model.train()
        self.ema_factor = ema_factor
        self.global_step = 0

    def forward(self, x, y, model, mask):
        self.global_step += 1
        y_hat = self.model(x)
        model.update_batch_stats(False)
        y = model(x) # recompute y since y as input of forward function is detached
        model.update_batch_stats(True)
        return (F.mse_loss(y.softmax(1), y_hat.softmax(1).detach(), reduction="none").mean(1) * mask).mean()

    def moving_average(self, parameters):
        ema_factor = min(1 - 1 / (self.global_step+1), self.ema_factor)
        for emp_p, p in zip(self.model.parameters(), parameters):
            emp_p.data = ema_factor * emp_p.data + (1 - ema_factor) * p.data

In [18]:
model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
optimizer = optim.Adam(model.parameters(), lr=alg_cfg["lr"])

trainable_paramters = sum([p.data.nelement() for p in model.parameters()])
print("trainable parameters : {}".format(trainable_paramters))

trainable parameters : 1467610


Mean Teacher는 Student 모델과 Teacher 모델이 존재하게 됩니다. 

In [19]:
t_model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
t_model.load_state_dict(model.state_dict())
ssl_obj = MT(t_model, alg_cfg["ema_factor"])

### Train

In [None]:
###
MT_cls_loss_list = []
MT_ssl_loss_list = []
MT_loss_list = []

MT_val_acc_list = []
###

iteration = 0
maximum_val_acc = 0
s = time.time()
for l_data, u_data in zip(l_loader, u_loader):
    iteration += 1
    l_input, target = l_data
    l_input, target = l_input.to(device).float(), target.to(device).long()

    u_input, dummy_target = u_data
    u_input, dummy_target = u_input.to(device).float(), dummy_target.to(device).long()

    target = torch.cat([target, dummy_target], 0)
    unlabeled_mask = (target == -1).float()

    inputs = torch.cat([l_input, u_input], 0)
    outputs = model(inputs)

    # ramp up exp(-5(1 - t)^2)
    coef = alg_cfg["consis_coef"] * math.exp(-5 * (1 - min(iteration/shared_config["warmup"], 1))**2)
    ssl_loss = ssl_obj(inputs, outputs.detach(), model, unlabeled_mask) * coef

    # supervised loss
    cls_loss = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean()

    loss = cls_loss + ssl_loss
    
    ###
    MT_cls_loss_list.append(cls_loss)
    MT_ssl_loss_list.append(ssl_loss)
    MT_loss_list.append(loss)
    ###

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # 표기
    if iteration == 1 or (iteration % 100) == 0:
        wasted_time = time.time() - s
        rest = (shared_config["iteration"] - iteration)/100 * wasted_time / 60
        print("iteration [{}/{}] cls loss : {:.6e}, SSL loss : {:.6e}, coef : {:.5e}, time : {:.3f} iter/sec, rest : {:.3f} min, lr : {}".format(
            iteration, shared_config["iteration"], cls_loss.item(), ssl_loss.item(), coef, 100 / wasted_time, rest, optimizer.param_groups[0]["lr"]),
            "\r", end="")
        s = time.time()
        
    # Validation
    if (iteration % args.validation) == 0 or iteration == shared_config["iteration"]:
        with torch.no_grad():
            model.eval()
            print()
            print("### validation ###")
            sum_acc = 0.
            s = time.time()
            for j, data in enumerate(val_loader):
                input, target = data
                input, target = input.to(device).float(), target.to(device).long()

                output = model(input)

                pred_label = output.max(1)[1]
                sum_acc += (pred_label == target).float().sum()
                if ((j+1) % 10) == 0:
                    d_p_s = 10/(time.time()-s)
                    print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                        j+1, len(val_loader), d_p_s, (len(val_loader) - j-1)/d_p_s
                    ), "\r", end="")
                    s = time.time()
            acc = sum_acc/float(len(val_dataset))
            print()
            print("varidation accuracy : {}".format(acc))
            
            ###
            MT_val_acc_list.append(acc)
            ###
            
            # Test
            if maximum_val_acc < acc:
                print("### test ###")
                maximum_val_acc = acc
                sum_acc = 0.
                s = time.time()
                for j, data in enumerate(test_loader):
                    input, target = data
                    input, target = input.to(device).float(), target.to(device).long()
                    output = model(input)
                    pred_label = output.max(1)[1]
                    sum_acc += (pred_label == target).float().sum()
                    if ((j+1) % 10) == 0:
                        d_p_s = 100/(time.time()-s)
                        print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                            j+1, len(test_loader), d_p_s, (len(test_loader) - j-1)/d_p_s
                        ), "\r", end="")
                        s = time.time()
                print()
                test_acc = sum_acc / float(len(test_dataset))
                print("test accuracy : {}".format(test_acc))
                # torch.save(model.state_dict(), os.path.join(args.output, "best_model.pth"))
        model.train()
        s = time.time()
        
    # lr decay
    if iteration == shared_config["lr_decay_iter"]:
        optimizer.param_groups[0]["lr"] *= shared_config["lr_decay_factor"]    
        
print("test acc : {}".format(test_acc))

iteration [1000/10000] cls loss : 2.663326e-02, SSL loss : 1.042584e-02, coef : 4.80437e-01, time : 17.386 iter/sec, rest : 8.628 min, lr : 0.0004 
### validation ###
[40/40] time : 133.2 data/sec, rest : 0.00 sec 
varidation accuracy : 0.4883999824523926
### test ###
[70/79] time : 1332.1 data/sec, rest : 0.01 sec 
test accuracy : 0.4846999943256378
iteration [1100/10000] cls loss : 3.664000e-02, SSL loss : 1.322712e-02, coef : 5.77710e-01, time : 16.441 iter/sec, rest : 9.022 min, lr : 0.0004 

## 3. ${\Pi}$ Model

<img src="https://user-images.githubusercontent.com/35906602/209688968-9a1acdfe-12bb-43d0-b3de-fe02c159c27e.png" width=1000>

* Reference : https://nuguziii.github.io/paper-review/PR-009/
* Laine and Alia. Temporal Ensembling for Semi-Supervised Learning. ICLR 2017

$\Pi$ Model에서는 같은 input에 대해서는 noise가 적용되어도 비슷한 결과를 보여야 한다는 것에서 착안, stochastic augmentation을 각각 다르게 적용합니다. 

Stochastic Augmentation과 Dropout을 이용해 동일한 입력 $x_i$ 에서 다른 출력 $z_i$ 와 $\tilde{z}_i$이 나타납니다.

다만 Training Target이 네트워크의 하나의 evaluation에 의해 얻어지기 때문에 noisy 하다는 문제점이 있습니다.

### Config

In [None]:
args.alg = 'pi'
pi_config = {
    # Pi Model
    "lr" : 3e-4,
    "consis_coef" : 20.0,
}
alg_cfg = pi_config

### Model

In [None]:
class PiModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y, model, mask):
        # NOTE:
        # stochastic transformation is embeded in forward function
        # so, pi-model is just to calculate consistency between two outputs
        model.update_batch_stats(False)
        y_hat = model(x)
        model.update_batch_stats(True)
        return (F.mse_loss(y_hat.softmax(1), y.softmax(1).detach(), reduction="none").mean(1) * mask).mean()

In [None]:
model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
optimizer = optim.Adam(model.parameters(), lr=alg_cfg["lr"])

trainable_paramters = sum([p.data.nelement() for p in model.parameters()])
print("trainable parameters : {}".format(trainable_paramters))

In [None]:
ssl_obj = PiModel()

### Train

In [None]:
###
pi_cls_loss_list = []
pi_ssl_loss_list = []
pi_loss_list = []

pi_val_acc_list = []
###

iteration = 0
maximum_val_acc = 0
s = time.time()
for l_data, u_data in zip(l_loader, u_loader):
    iteration += 1
    l_input, target = l_data
    l_input, target = l_input.to(device).float(), target.to(device).long()

    u_input, dummy_target = u_data
    u_input, dummy_target = u_input.to(device).float(), dummy_target.to(device).long()

    target = torch.cat([target, dummy_target], 0)
    unlabeled_mask = (target == -1).float()

    inputs = torch.cat([l_input, u_input], 0)
    outputs = model(inputs)

    # ramp up exp(-5(1 - t)^2)
    coef = alg_cfg["consis_coef"] * math.exp(-5 * (1 - min(iteration/shared_config["warmup"], 1))**2)
    ssl_loss = ssl_obj(inputs, outputs.detach(), model, unlabeled_mask) * coef

    # supervised loss
    cls_loss = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean()

    loss = cls_loss + ssl_loss
    
    ###
    pi_cls_loss_list.append(cls_loss)
    pi_ssl_loss_list.append(ssl_loss)
    pi_loss_list.append(loss)
    ###

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # 표기
    if iteration == 1 or (iteration % 100) == 0:
        wasted_time = time.time() - s
        rest = (shared_config["iteration"] - iteration)/100 * wasted_time / 60
        print("iteration [{}/{}] cls loss : {:.6e}, SSL loss : {:.6e}, coef : {:.5e}, time : {:.3f} iter/sec, rest : {:.3f} min, lr : {}".format(
            iteration, shared_config["iteration"], cls_loss.item(), ssl_loss.item(), coef, 100 / wasted_time, rest, optimizer.param_groups[0]["lr"]),
            "\r", end="")
        s = time.time()
        
    # Validation
    if (iteration % args.validation) == 0 or iteration == shared_config["iteration"]:
        with torch.no_grad():
            model.eval()
            print()
            print("### validation ###")
            sum_acc = 0.
            s = time.time()
            for j, data in enumerate(val_loader):
                input, target = data
                input, target = input.to(device).float(), target.to(device).long()

                output = model(input)

                pred_label = output.max(1)[1]
                sum_acc += (pred_label == target).float().sum()
                if ((j+1) % 10) == 0:
                    d_p_s = 10/(time.time()-s)
                    print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                        j+1, len(val_loader), d_p_s, (len(val_loader) - j-1)/d_p_s
                    ), "\r", end="")
                    s = time.time()
            acc = sum_acc/float(len(val_dataset))
            print()
            print("varidation accuracy : {}".format(acc))
            
            ###
            pi_val_acc_list.append(acc)
            ###
            
            # Test
            if maximum_val_acc < acc:
                print("### test ###")
                maximum_val_acc = acc
                sum_acc = 0.
                s = time.time()
                for j, data in enumerate(test_loader):
                    input, target = data
                    input, target = input.to(device).float(), target.to(device).long()
                    output = model(input)
                    pred_label = output.max(1)[1]
                    sum_acc += (pred_label == target).float().sum()
                    if ((j+1) % 10) == 0:
                        d_p_s = 100/(time.time()-s)
                        print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                            j+1, len(test_loader), d_p_s, (len(test_loader) - j-1)/d_p_s
                        ), "\r", end="")
                        s = time.time()
                print()
                test_acc = sum_acc / float(len(test_dataset))
                print("test accuracy : {}".format(test_acc))
                # torch.save(model.state_dict(), os.path.join(args.output, "best_model.pth"))
        model.train()
        s = time.time()
        
    # lr decay
    if iteration == shared_config["lr_decay_iter"]:
        optimizer.param_groups[0]["lr"] *= shared_config["lr_decay_factor"]    
        
print("test acc : {}".format(test_acc))

## 4. ICT

<img src="https://user-images.githubusercontent.com/35906602/209689311-5deeee55-ae48-4fcb-80e6-720febca5e03.png" width="1000">

* Reference : https://jiwunghyun.medium.com/semi-supervised-learning-%EC%A0%95%EB%A6%AC-a7ed58a8f023
* Vikas Verma et al. Interpolation Consistency Training for Semi-Supervised Learning. IJCAI 2019

Mixup을 Semi-supervised learning에 적용한 방법입니다. (Mixup한 데이터에 대한 모델 결과)와 (unlabeled sample의 모델 결과의 Mixup) 차이가 consistency loss가 됩니다.

### Config

In [None]:
args.alg = 'ICT'
ict_config = {
    # interpolation consistency training
    "ema_factor" : 0.999,
    "lr" : 4e-4,
    "consis_coef" : 100,
    "alpha" : 0.1,
}
alg_cfg = ict_config

### Model

In [None]:
class ICT(nn.Module):
    def __init__(self, alpha, model, ema_factor):
        super().__init__()
        self.alpha = alpha
        self.mean_teacher = model
        self.mean_teacher.train()
        self.ema_factor = ema_factor
        self.global_step = 0

    def forward(self, x, y, model, mask):
        self.global_step += 1 # for moving average coef
        mask = mask.byte()
        model.update_batch_stats(False)
        mt_y = self.mean_teacher(x).detach()
        u_x, u_y = x[mask], mt_y[mask]
        l_x, l_y = x[mask==0], mt_y[mask==0]
        lam = np.random.beta(self.alpha, self.alpha) # sample mixup coef
        perm = torch.randperm(u_x.shape[0])
        perm_u_x, perm_u_y = u_x[perm], u_y[perm]
        mixed_u_x = lam * u_x + (1 - lam) * perm_u_x
        mixed_u_y = (lam * u_y + (1 - lam) * perm_u_y).detach()
        y_hat = model(torch.cat([l_x, mixed_u_x], 0)) # "cat" indicates to compute batch stats from full batches
        loss = F.mse_loss(y_hat.softmax(1), torch.cat([l_y, mixed_u_y], 0).softmax(1), reduction="none").sum(1)
        # compute loss for only unlabeled data, but loss is normalized by full batchsize
        loss = loss[l_x.shape[0]:].sum() / x.shape[0]
        model.update_batch_stats(True)
        return loss

    def moving_average(self, parameters):
        ema_factor = min(1 - 1 / (self.global_step), self.ema_factor)
        for emp_p, p in zip(self.mean_teacher.parameters(), parameters):
            emp_p.data = ema_factor * emp_p.data + (1 - ema_factor) * p.data


In [None]:
model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
optimizer = optim.Adam(model.parameters(), lr=alg_cfg["lr"])

trainable_paramters = sum([p.data.nelement() for p in model.parameters()])
print("trainable parameters : {}".format(trainable_paramters))

In [None]:
t_model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
t_model.load_state_dict(model.state_dict())
ssl_obj = ICT(alg_cfg["alpha"], t_model, alg_cfg["ema_factor"])

### Train

In [None]:
###
ict_cls_loss_list = []
ict_ssl_loss_list = []
ict_loss_list = []

ict_val_acc_list = []
###

iteration = 0
maximum_val_acc = 0
s = time.time()
for l_data, u_data in zip(l_loader, u_loader):
    iteration += 1
    l_input, target = l_data
    l_input, target = l_input.to(device).float(), target.to(device).long()

    u_input, dummy_target = u_data
    u_input, dummy_target = u_input.to(device).float(), dummy_target.to(device).long()

    target = torch.cat([target, dummy_target], 0)
    unlabeled_mask = (target == -1).float()

    inputs = torch.cat([l_input, u_input], 0)
    outputs = model(inputs)

    # ramp up exp(-5(1 - t)^2)
    coef = alg_cfg["consis_coef"] * math.exp(-5 * (1 - min(iteration/shared_config["warmup"], 1))**2)
    ssl_loss = ssl_obj(inputs, outputs.detach(), model, unlabeled_mask) * coef

    # supervised loss
    cls_loss = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean()

    loss = cls_loss + ssl_loss
    
    ###
    ict_cls_loss_list.append(cls_loss)
    ict_ssl_loss_list.append(ssl_loss)
    ict_loss_list.append(loss)
    ###

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # 표기
    if iteration == 1 or (iteration % 100) == 0:
        wasted_time = time.time() - s
        rest = (shared_config["iteration"] - iteration)/100 * wasted_time / 60
        print("iteration [{}/{}] cls loss : {:.6e}, SSL loss : {:.6e}, coef : {:.5e}, time : {:.3f} iter/sec, rest : {:.3f} min, lr : {}".format(
            iteration, shared_config["iteration"], cls_loss.item(), ssl_loss.item(), coef, 100 / wasted_time, rest, optimizer.param_groups[0]["lr"]),
            "\r", end="")
        s = time.time()
        
    # Validation
    if (iteration % args.validation) == 0 or iteration == shared_config["iteration"]:
        with torch.no_grad():
            model.eval()
            print()
            print("### validation ###")
            sum_acc = 0.
            s = time.time()
            for j, data in enumerate(val_loader):
                input, target = data
                input, target = input.to(device).float(), target.to(device).long()

                output = model(input)

                pred_label = output.max(1)[1]
                sum_acc += (pred_label == target).float().sum()
                if ((j+1) % 10) == 0:
                    d_p_s = 10/(time.time()-s)
                    print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                        j+1, len(val_loader), d_p_s, (len(val_loader) - j-1)/d_p_s
                    ), "\r", end="")
                    s = time.time()
            acc = sum_acc/float(len(val_dataset))
            print()
            print("varidation accuracy : {}".format(acc))
            
            ###
            ict_val_acc_list.append(acc)
            ###
            
            # Test
            if maximum_val_acc < acc:
                print("### test ###")
                maximum_val_acc = acc
                sum_acc = 0.
                s = time.time()
                for j, data in enumerate(test_loader):
                    input, target = data
                    input, target = input.to(device).float(), target.to(device).long()
                    output = model(input)
                    pred_label = output.max(1)[1]
                    sum_acc += (pred_label == target).float().sum()
                    if ((j+1) % 10) == 0:
                        d_p_s = 100/(time.time()-s)
                        print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                            j+1, len(test_loader), d_p_s, (len(test_loader) - j-1)/d_p_s
                        ), "\r", end="")
                        s = time.time()
                print()
                test_acc = sum_acc / float(len(test_dataset))
                print("test accuracy : {}".format(test_acc))
                # torch.save(model.state_dict(), os.path.join(args.output, "best_model.pth"))
        model.train()
        s = time.time()
        
    # lr decay
    if iteration == shared_config["lr_decay_iter"]:
        optimizer.param_groups[0]["lr"] *= shared_config["lr_decay_factor"]    
        
print("test acc : {}".format(test_acc))

## 5. MixMatch

<img src="https://user-images.githubusercontent.com/35906602/209689424-e0675460-c8ac-473e-9286-1713a9f8cf28.png" width="900">

* Reference : https://jiwunghyun.medium.com/semi-supervised-learning-%EC%A0%95%EB%A6%AC-a7ed58a8f023
* David Berthelot et al. MixMatch: A Holistic Approach for Semi-Supervised Learning. NeurIPS 2019.

앞에 나온 entropy minimization, label consistency regularization, mixup을 모두 적용한 방법입니다. MixMatch는 labeled data와 unlabeled data를 받아서 결합된 데이터를 만듭니다.

Unlabeled data에 대하여 K번의 augmentation을 하고 prediction의 평균을 구하고 그 값을 temperature sharpening을 통하여 sharpen 하며, Augmentation된 labeled, unlabeled 데이터를 섞고, 그 데이터에 대하여 labeled data와 unlabeled data에 MixUp을 합니다.

학습은 다른 모델과 같이 supervised loss는 CE, unsupervised loss는 모델 출력 값의 차이 (L2)가 됩니다.

### Config

In [None]:
args.alg = 'MM'
mm_config = {
    # mixmatch
    "lr" : 3e-3,
    "consis_coef" : 100,
    "alpha" : 0.75,
    "T" : 0.5,
    "K" : 2,
}
alg_cfg = mm_config

### Model

In [None]:
class MixMatch(nn.Module):
    def __init__(self, temperature, n_augment, alpha):
        super().__init__()
        self.T = temperature
        self.K = n_augment
        self.beta_distirb = torch.distributions.beta.Beta(alpha, alpha)

    def sharpen(self, y):
        y = y.pow(1/self.T)
        return y / y.sum(1,keepdim=True)

    def forward(self, x, y, model, mask):
        # NOTE: this implementaion uses mixup for only unlabeled data
        model.update_batch_stats(False)
        u_x = x[mask == 1]
        # K augmentation and make prediction labels
        u_x_hat = [u_x for _ in range(self.K)]
        y_hat = sum([model(u_x_hat[i]).softmax(1) for i in range(len(u_x_hat))]) / self.K
        y_hat = self.sharpen(y_hat)
        y_hat = y_hat.repeat(len(u_x_hat), 1)
        # mixup
        u_x_hat = torch.cat(u_x_hat, 0)
        index = torch.randperm(u_x_hat.shape[0])
        shuffled_u_x_hat, shuffled_y_hat = u_x_hat[index], y_hat[index]
        lam = self.beta_distirb.sample().item()
        # lam = max(lam, 1-lam)
        mixed_x = lam * u_x_hat + (1-lam) * shuffled_u_x_hat
        mixed_y = lam * y_hat + (1-lam) * shuffled_y_hat.softmax(1)
        # mean squared error
        loss = F.mse_loss(model(mixed_x), mixed_y)
        model.update_batch_stats(True)
        return loss

In [None]:
model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
optimizer = optim.Adam(model.parameters(), lr=alg_cfg["lr"])

trainable_paramters = sum([p.data.nelement() for p in model.parameters()])
print("trainable parameters : {}".format(trainable_paramters))

In [None]:
ssl_obj = MixMatch(alg_cfg["T"], alg_cfg["K"], alg_cfg["alpha"])

### Train

In [None]:
###
mm_cls_loss_list = []
mm_ssl_loss_list = []
mm_loss_list = []

mm_val_acc_list = []
###

iteration = 0
maximum_val_acc = 0
s = time.time()
for l_data, u_data in zip(l_loader, u_loader):
    iteration += 1
    l_input, target = l_data
    l_input, target = l_input.to(device).float(), target.to(device).long()

    u_input, dummy_target = u_data
    u_input, dummy_target = u_input.to(device).float(), dummy_target.to(device).long()

    target = torch.cat([target, dummy_target], 0)
    unlabeled_mask = (target == -1).float()

    inputs = torch.cat([l_input, u_input], 0)
    outputs = model(inputs)

    # ramp up exp(-5(1 - t)^2)
    coef = alg_cfg["consis_coef"] * math.exp(-5 * (1 - min(iteration/shared_config["warmup"], 1))**2)
    ssl_loss = ssl_obj(inputs, outputs.detach(), model, unlabeled_mask) * coef

    # supervised loss
    cls_loss = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean()

    loss = cls_loss + ssl_loss
    
    ###
    mm_cls_loss_list.append(cls_loss)
    mm_ssl_loss_list.append(ssl_loss)
    mm_loss_list.append(loss)
    ###

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # 표기
    if iteration == 1 or (iteration % 100) == 0:
        wasted_time = time.time() - s
        rest = (shared_config["iteration"] - iteration)/100 * wasted_time / 60
        print("iteration [{}/{}] cls loss : {:.6e}, SSL loss : {:.6e}, coef : {:.5e}, time : {:.3f} iter/sec, rest : {:.3f} min, lr : {}".format(
            iteration, shared_config["iteration"], cls_loss.item(), ssl_loss.item(), coef, 100 / wasted_time, rest, optimizer.param_groups[0]["lr"]),
            "\r", end="")
        s = time.time()
        
    # Validation
    if (iteration % args.validation) == 0 or iteration == shared_config["iteration"]:
        with torch.no_grad():
            model.eval()
            print()
            print("### validation ###")
            sum_acc = 0.
            s = time.time()
            for j, data in enumerate(val_loader):
                input, target = data
                input, target = input.to(device).float(), target.to(device).long()

                output = model(input)

                pred_label = output.max(1)[1]
                sum_acc += (pred_label == target).float().sum()
                if ((j+1) % 10) == 0:
                    d_p_s = 10/(time.time()-s)
                    print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                        j+1, len(val_loader), d_p_s, (len(val_loader) - j-1)/d_p_s
                    ), "\r", end="")
                    s = time.time()
            acc = sum_acc/float(len(val_dataset))
            print()
            print("varidation accuracy : {}".format(acc))
            
            ###
            mm_val_acc_list.append(acc)
            ###
            
            # Test
            if maximum_val_acc < acc:
                print("### test ###")
                maximum_val_acc = acc
                sum_acc = 0.
                s = time.time()
                for j, data in enumerate(test_loader):
                    input, target = data
                    input, target = input.to(device).float(), target.to(device).long()
                    output = model(input)
                    pred_label = output.max(1)[1]
                    sum_acc += (pred_label == target).float().sum()
                    if ((j+1) % 10) == 0:
                        d_p_s = 100/(time.time()-s)
                        print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                            j+1, len(test_loader), d_p_s, (len(test_loader) - j-1)/d_p_s
                        ), "\r", end="")
                        s = time.time()
                print()
                test_acc = sum_acc / float(len(test_dataset))
                print("test accuracy : {}".format(test_acc))
                # torch.save(model.state_dict(), os.path.join(args.output, "best_model.pth"))
        model.train()
        s = time.time()
        
    # lr decay
    if iteration == shared_config["lr_decay_iter"]:
        optimizer.param_groups[0]["lr"] *= shared_config["lr_decay_factor"]    
        
print("test acc : {}".format(test_acc))