In [None]:
*Semi-Supervised Learning: a hybrid of supervised learning and unsupervised learning
- Used when a small amount of labeled data and a large amount of unlabeled data are available
- A supervised loss is applied to the labeled data,
- while an unsupervised loss is applied to the unlabeled data
--> the goal is a model whose outputs for an unlabeled sample and for a perturbed version of that sample differ as little as possible
- Two representative families of methods: Consistency Regularization and Holistic Methods

<Consistency Regularization models>
1. Pi-Model(2016)
2. Temporal Ensemble(2016)
3. Virtual Adversarial Training(2017)
4. Mean Teacher(2017)

*The four models above are compared under identical conditions, in five comparison experiments across two datasets

## Dataset: CIFAR-10 and MNIST

In [None]:
*The datasets used are CIFAR-10 and MNIST
*CIFAR-10: color images of shape (3,32,32) in 10 classes.
    It consists of 60000 images (50000 for training, 10000 for testing)
*MNIST: grayscale images of shape (1,28,28) showing handwritten digits 0~9 (so the number of classes is also 10).
    It consists of 70000 images (60000 for training, 10000 for testing)

## Comparison on the CIFAR-10 dataset

## Pi-Model(2016) vs Mean Teacher(2017) vs VAT(2017)

In [None]:
*Pi-Model(2016):
    Whereas the Ladder Network (2015) enforces consistency between layer-wise latent vectors,
    the Pi-Model enforces consistency between output vectors rather than latent vectors
    
    A single FFN (Feed-Forward Neural Network) processes each input twice, each time under a different perturbation
    - Supervised loss: Cross Entropy
    - Unsupervised loss: MSE
    - Total loss = Cross Entropy + w*MSE (w is the weight of the unsupervised term)
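The loss described above can be sketched in PyTorch as follows. This is a minimal illustration and not the `pimodel` module used later in this notebook: the toy network, the Gaussian input noise, and the function name `pi_model_loss` are assumptions, and for simplicity the consistency term is computed on every row of the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy network; dropout is one source of stochastic perturbation.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 3))
net.train()  # keep dropout active so the two passes differ


def add_noise(t, std=0.1):
    # Input perturbation (assumed Gaussian noise for this sketch).
    return t + std * torch.randn_like(t)


def pi_model_loss(net, x, y, w):
    """Cross entropy on labeled rows plus w * MSE between two perturbed passes.

    Rows with y == -1 are treated as unlabeled.
    """
    logits_a = net(add_noise(x))
    logits_b = net(add_noise(x))
    labeled = y != -1
    sup = F.cross_entropy(logits_a[labeled], y[labeled])
    # The second pass serves as a fixed target, so no gradient flows through it.
    unsup = F.mse_loss(logits_a.softmax(1), logits_b.softmax(1).detach())
    return sup + w * unsup, sup, unsup


x = torch.randn(6, 8)
y = torch.tensor([0, 2, -1, -1, -1, -1])  # two labeled rows, four unlabeled
total, sup, unsup = pi_model_loss(net, x, y, w=20.0)
total.backward()
```

The weight `w=20.0` mirrors the `consis_coef` value printed for the PI configuration further down; in the training loop that weight is additionally ramped up during warmup.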

In [None]:
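* Mean Teacher(2017): 
    In Temporal Ensembling, each target is updated only once per epoch, so newly learned information is incorporated into training slowly
    
    In the Pi-Model, the same model (architecture) plays both the teacher and the student role
    --> the targets are more likely to be misclassified
    Therefore, unlike in the Pi-Model, the quality of the targets needs to be improved!
    
    -How to improve it: choose the perturbations carefully, or use a teacher model that differs from the student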
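In the Mean Teacher paper, the teacher differs from the student by holding an exponential moving average (EMA) of the student's weights. The notebook's MT configuration exposes this as `ema_factor`. Below is a minimal plain-Python sketch of that update rule, with hypothetical scalar "weights" standing in for parameter tensors; it is not the code of the `mean_teacher` module imported later.

```python
def ema_update(teacher, student, ema_factor):
    """In place: teacher <- ema_factor * teacher + (1 - ema_factor) * student."""
    for name, s_val in student.items():
        teacher[name] = ema_factor * teacher[name] + (1.0 - ema_factor) * s_val
    return teacher


# Hypothetical scalar weights; a real model would hold tensors here.
student = {"w": 1.0, "b": 0.0}
teacher = dict(student)      # the teacher starts as a copy of the student

student["w"] = 2.0           # pretend one optimizer step moved the student
ema_update(teacher, student, ema_factor=0.95)
print(teacher["w"])          # approximately 0.95 * 1.0 + 0.05 * 2.0 = 1.05
```

Because the average is refreshed after every optimizer step, the teacher tracks the student far more often than a once-per-epoch target update would.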

In [None]:
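* Virtual Adversarial Training(2017):
    Uses adversarial training to train against the perturbation direction in which the model is most vulnerable
    --> improves the robustness of the model
    
    Training uses a loss computed between the outputs for the original image and for its adversarially perturbed version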
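The idea can be sketched as follows, assuming the usual formulation of VAT: one power-iteration step estimates the most sensitive input direction, and the loss is the KL divergence between the predictions on the original and the perturbed input. The names `eps` and `xi` mirror the keys of the notebook's VAT configuration; the helper names and the toy model are hypothetical, and this is not the `vat` module imported later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def l2_normalize(d):
    # Rescale by the max magnitude first so tiny gradients do not underflow,
    # then scale each sample's perturbation to unit L2 norm.
    flat = d.flatten(1)
    flat = flat / (flat.abs().amax(dim=1, keepdim=True) + 1e-12)
    return (flat / (flat.norm(dim=1, keepdim=True) + 1e-8)).view_as(d)


def vat_loss(model, x, eps=1.0, xi=1e-6, n_power=1):
    """KL(p(y|x) || p(y|x + r_adv)), with r_adv estimated by power iteration."""
    with torch.no_grad():
        p = F.softmax(model(x), dim=1)            # fixed target distribution
    d = l2_normalize(torch.randn_like(x))         # random starting direction
    for _ in range(n_power):
        d.requires_grad_(True)
        log_q = F.log_softmax(model(x + xi * d), dim=1)
        dist = F.kl_div(log_q, p, reduction="batchmean")
        grad = torch.autograd.grad(dist, d)[0]    # direction of fastest KL growth
        d = l2_normalize(grad.detach())
    log_q_adv = F.log_softmax(model(x + eps * d), dim=1)
    return F.kl_div(log_q_adv, p, reduction="batchmean")


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))  # toy model
x = torch.randn(4, 8)
loss = vat_loss(model, x, eps=1.0)
loss.backward()
```

No labels appear anywhere in `vat_loss`, which is why the term can be applied to unlabeled data.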

In [None]:
*Source: https://github.com/perrying/realistic-ssl-evaluation-pytorch

#### Building the dataset

In [17]:
os.getcwd()

'E:\\coursework\\ba\\2022_BA_donghwanshin\\05. Semi-supervised Learning'

In [2]:
from torchvision import datasets
import argparse, os
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument("--seed", "-s", default=1, type=int, help="random seed")
parser.add_argument("--dataset", "-d", default="cifar10", type=str, help="dataset name : [svhn, cifar10]")
parser.add_argument("--nlabels", "-n", default=1000, type=int, help="the number of labeled data")
args, _ = parser.parse_known_args()

COUNTS = {
    "svhn": {"train": 73257, "test": 26032, "valid": 7326, "extra": 531131},
    "cifar10": {"train": 50000, "test": 10000, "valid": 5000, "extra": 0},
    "imagenet_32": {
        "train": 1281167,
        "test": 50000,
        "valid": 50050,
        "extra": 0,
    },
}

_DATA_DIR = "./data"

def split_l_u(train_set, n_labels):
    # NOTE: this function assumes that train_set is shuffled.
    images = train_set["images"]
    labels = train_set["labels"]
    classes = np.unique(labels)
    n_labels_per_cls = n_labels // len(classes)
    l_images = []
    l_labels = []
    u_images = []
    u_labels = []
    for c in classes:
        cls_mask = (labels == c)
        c_images = images[cls_mask]
        c_labels = labels[cls_mask]
        l_images += [c_images[:n_labels_per_cls]]
        l_labels += [c_labels[:n_labels_per_cls]]
        u_images += [c_images[n_labels_per_cls:]]
        u_labels += [np.zeros_like(c_labels[n_labels_per_cls:]) - 1] # dummy label (-1 marks unlabeled samples)
    l_train_set = {"images": np.concatenate(l_images, 0), "labels": np.concatenate(l_labels, 0)}
    u_train_set = {"images": np.concatenate(u_images, 0), "labels": np.concatenate(u_labels, 0)}
    return l_train_set, u_train_set

def _load_svhn():
    splits = {}
    for split in ["train", "test", "extra"]:
        tv_data = datasets.SVHN(_DATA_DIR, split, download=True)
        data = {}
        data["images"] = tv_data.data
        data["labels"] = tv_data.labels
        splits[split] = data
    return splits.values()

def _load_cifar10():
    splits = {}
    for train in [True, False]:
        tv_data = datasets.CIFAR10(_DATA_DIR, train, download=True)
        data = {}
        data["images"] = tv_data.data
        data["labels"] = np.array(tv_data.targets)
        splits["train" if train else "test"] = data
    return splits.values()

def gcn(images, multiplier=55, eps=1e-10):
    # global contrast normalization
    images = images.astype(np.float64)  # the np.float alias is deprecated since NumPy 1.20
    images -= images.mean(axis=(1,2,3), keepdims=True)
    per_image_norm = np.sqrt(np.square(images).sum((1,2,3), keepdims=True))
    per_image_norm[per_image_norm < eps] = 1
    return multiplier * images / per_image_norm

def get_zca_normalization_param(images, scale=0.1, eps=1e-10):
    n_data, height, width, channels = images.shape
    images = images.reshape(n_data, height*width*channels)
    image_cov = np.cov(images, rowvar=False)
    U, S, _ = np.linalg.svd(image_cov + scale * np.eye(image_cov.shape[0]))
    zca_decomp = np.dot(U, np.dot(np.diag(1/np.sqrt(S + eps)), U.T))
    mean = images.mean(axis=0)
    return mean, zca_decomp

def zca_normalization(images, mean, decomp):
    n_data, height, width, channels = images.shape
    images = images.reshape(n_data, -1)
    images = np.dot((images - mean), decomp)
    return images.reshape(n_data, height, width, channels)

rng = np.random.RandomState(args.seed)

validation_count = COUNTS[args.dataset]["valid"]

extra_set = None  # In general, there won't be extra data.
if args.dataset == "svhn":
    train_set, test_set, extra_set = _load_svhn()
elif args.dataset == "cifar10":
    train_set, test_set = _load_cifar10()
    train_set["images"] = gcn(train_set["images"])
    test_set["images"] = gcn(test_set["images"])
    mean, zca_decomp = get_zca_normalization_param(train_set["images"])
    train_set["images"] = zca_normalization(train_set["images"], mean, zca_decomp)
    test_set["images"] = zca_normalization(test_set["images"], mean, zca_decomp)
    # N x H x W x C -> N x C x H x W
    train_set["images"] = np.transpose(train_set["images"], (0,3,1,2))
    test_set["images"] = np.transpose(test_set["images"], (0,3,1,2))

# permute index of training set
indices = rng.permutation(len(train_set["images"]))
train_set["images"] = train_set["images"][indices]
train_set["labels"] = train_set["labels"][indices]

if extra_set is not None:
    extra_indices = rng.permutation(len(extra_set["images"]))
    extra_set["images"] = extra_set["images"][extra_indices]
    extra_set["labels"] = extra_set["labels"][extra_indices]

# split training set into training and validation
train_images = train_set["images"][validation_count:]
train_labels = train_set["labels"][validation_count:]
validation_images = train_set["images"][:validation_count]
validation_labels = train_set["labels"][:validation_count]
validation_set = {"images": validation_images, "labels": validation_labels}
train_set = {"images": train_images, "labels": train_labels}

# split training set into labeled data and unlabeled data
l_train_set, u_train_set = split_l_u(train_set, args.nlabels)

if not os.path.exists(os.path.join(_DATA_DIR, args.dataset)):
    os.mkdir(os.path.join(_DATA_DIR, args.dataset))

np.save(os.path.join(_DATA_DIR, args.dataset, "l_train"), l_train_set)
np.save(os.path.join(_DATA_DIR, args.dataset, "u_train"), u_train_set)
np.save(os.path.join(_DATA_DIR, args.dataset, "val"), validation_set)
np.save(os.path.join(_DATA_DIR, args.dataset, "test"), test_set)
if extra_set is not None:
    np.save(os.path.join(_DATA_DIR, args.dataset, "extra"), extra_set)

Files already downloaded and verified
Files already downloaded and verified


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
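The ZCA step in the preprocessing above can be checked on toy data. The sketch below follows the same recipe as `get_zca_normalization_param` and `zca_normalization` (covariance, SVD, then `U diag(1/sqrt(S)) U^T`), but on hypothetical 2-D correlated data and without the `scale * I` regularizer that the notebook adds to the covariance, so that the whitened covariance comes out close to the identity.

```python
import numpy as np

rng = np.random.RandomState(0)

# Hypothetical correlated 2-D data standing in for flattened images.
mix = np.array([[2.0, 0.0], [1.5, 0.5]])
x = rng.randn(5000, 2) @ mix.T

mean = x.mean(axis=0)
cov = np.cov(x, rowvar=False)
U, S, _ = np.linalg.svd(cov)
zca = U @ np.diag(1.0 / np.sqrt(S + 1e-10)) @ U.T   # symmetric whitening matrix

x_white = (x - mean) @ zca
cov_white = np.cov(x_white, rowvar=False)            # close to the identity
print(np.round(cov_white, 6))
```

With the regularizer used in the notebook (`scale=0.1`), directions with small variance are amplified less, so the result is only approximately white.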


#### Loading the data and model

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import argparse, math, time, json, os

import wrn, transform
from config import config

parser = argparse.ArgumentParser()
parser.add_argument("--alg", "-a", default="PI", type=str, help="ssl algorithm : [supervised, PI, MT, VAT, PL, ICT]")
parser.add_argument("--em", default=0, type=float, help="coefficient of entropy minimization. If you try VAT + EM, set 0.06")
parser.add_argument("--validation", default=25000, type=int, help="validate at this interval (default 25000)")
parser.add_argument("--dataset", "-d", default="svhn", type=str, help="dataset name : [svhn, cifar10]")
parser.add_argument("--root", "-r", default="data", type=str, help="dataset dir")
parser.add_argument("--output", "-o", default="./exp_res", type=str, help="output dir")
args, _ = parser.parse_known_args()

In [4]:
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'


condition = {}
exp_name = ""
args.dataset = 'cifar10'
condition["dataset"] = args.dataset
exp_name += str(args.dataset) + "_"
dataset_cfg = config[args.dataset]
transform_fn = transform.transform(*dataset_cfg["transform"]) # transform function (flip, crop, noise)


l_train_dataset = dataset_cfg["dataset"](args.root, "l_train")
u_train_dataset = dataset_cfg["dataset"](args.root, "u_train")
val_dataset = dataset_cfg["dataset"](args.root, "val")
test_dataset = dataset_cfg["dataset"](args.root, "test")

print("labeled data : {}, unlabeled data : {}, training data : {}".format(
    len(l_train_dataset), len(u_train_dataset), len(l_train_dataset)+len(u_train_dataset)))
print("validation data : {}, test data : {}".format(len(val_dataset), len(test_dataset)))
condition["number_of_data"] = {
    "labeled":len(l_train_dataset), "unlabeled":len(u_train_dataset),
    "validation":len(val_dataset), "test":len(test_dataset)
}

holizontal flip : True, random crop : True, gaussian noise : True
labeled data : 1000, unlabeled data : 44000, training data : 45000
validation data : 5000, test data : 10000
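The counts printed above follow from the split logic in `split_l_u`: the labeled budget is divided evenly over the classes and everything else becomes unlabeled. Here is a plain-Python sketch of the same idea on a hypothetical toy label vector; the function name `split_labeled_unlabeled` is an assumption for illustration.

```python
from collections import defaultdict


def split_labeled_unlabeled(labels, n_labels):
    """Return (labeled_idx, unlabeled_idx) with n_labels spread evenly over classes."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    per_class = n_labels // len(by_class)
    labeled, unlabeled = [], []
    for label in sorted(by_class):
        labeled += by_class[label][:per_class]
        unlabeled += by_class[label][per_class:]
    return labeled, unlabeled


# Hypothetical toy labels: 3 classes with 5 samples each.
labels = [i % 3 for i in range(15)]
labeled, unlabeled = split_labeled_unlabeled(labels, n_labels=6)
print(len(labeled), len(unlabeled))  # 6 9

# Same arithmetic as the CIFAR-10 run above.
n_train, n_valid, n_labels = 50000, 5000, 1000
print(n_train - n_valid - n_labels)  # 44000
```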


In [5]:
config["shared"]

{'iteration': 8800,
 'warmup': 1000,
 'lr_decay_iter': 400,
 'lr_decay_factor': 0.2,
 'batch_size': 256}

In [6]:
class RandomSampler(torch.utils.data.Sampler):
    """ sampling without replacement within each pass; indices repeat across passes """
    def __init__(self, num_data, num_sample):
        iterations = num_sample // num_data + 1
        self.indices = torch.cat([torch.randperm(num_data) for _ in range(iterations)]).tolist()[:num_sample]

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

shared_cfg = config["shared"]
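The index stream built by `RandomSampler` can be illustrated without PyTorch. This is a minimal sketch of the same logic using the standard library; the function name and the seed are assumptions. It concatenates enough shuffled passes over the dataset and truncates to the requested length.

```python
import random


def repeated_permutation_indices(num_data, num_sample, seed=0):
    """Concatenate shuffled passes over range(num_data), truncated to num_sample."""
    rng = random.Random(seed)
    passes = num_sample // num_data + 1
    indices = []
    for _ in range(passes):
        perm = list(range(num_data))
        rng.shuffle(perm)
        indices.extend(perm)
    return indices[:num_sample]


idx = repeated_permutation_indices(num_data=5, num_sample=12)
print(len(idx))          # 12
print(sorted(idx[:5]))   # [0, 1, 2, 3, 4]: the first pass covers every item once

# With the shared config (8800 iterations, batch size 256 split in half),
# each loader yields 8800 batches of 128 samples.
print(8800 * 256 // 2 // 128)  # 8800
```

This is how 1,000 labeled images can feed 8,800 iterations: the labeled set is simply cycled through many times in a fresh random order.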

### Pi-model

In [7]:
args.alg = 'PI'

In [8]:
if args.alg != "supervised":
    # batch size = 0.5 x batch size
    l_loader = DataLoader(
        l_train_dataset, shared_cfg["batch_size"]//2, drop_last=True,
        sampler=RandomSampler(len(l_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"]//2)
    )
else:
    l_loader = DataLoader(
        l_train_dataset, shared_cfg["batch_size"], drop_last=True,
        sampler=RandomSampler(len(l_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"])
    )
print("algorithm : {}".format(args.alg))
condition["algorithm"] = args.alg
exp_name += str(args.alg) + "_"

u_loader = DataLoader(
    u_train_dataset, shared_cfg["batch_size"]//2, drop_last=True,
    sampler=RandomSampler(len(u_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"]//2)
)

val_loader = DataLoader(val_dataset, 128, shuffle=False, drop_last=False)
test_loader = DataLoader(test_dataset, 128, shuffle=False, drop_last=False)

print("maximum iteration : {}".format(min(len(l_loader), len(u_loader))))

alg_cfg = config[args.alg]
print("parameters : ", alg_cfg)
condition["h_parameters"] = alg_cfg

if args.em > 0:
    print("entropy minimization : {}".format(args.em))
    exp_name += "em_"
condition["entropy_maximization"] = args.em

model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
optimizer = optim.Adam(model.parameters(), lr=alg_cfg["lr"])

trainable_paramters = sum([p.data.nelement() for p in model.parameters()])
print("trainable parameters : {}".format(trainable_paramters))

algorithm : PI
maximum iteration : 8800
parameters :  {'lr': 0.0003, 'consis_coef': 20.0}
trainable parameters : 1467610


In [9]:
if args.alg == "VAT": # virtual adversarial training
    from vat import VAT
    ssl_obj = VAT(alg_cfg["eps"][args.dataset], alg_cfg["xi"], 1)
elif args.alg == "MT": # mean teacher
    from mean_teacher import MT
    t_model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
    t_model.load_state_dict(model.state_dict())
    ssl_obj = MT(t_model, alg_cfg["ema_factor"])
elif args.alg == "PI": # PI Model
    from pimodel import PiModel
    ssl_obj = PiModel()
else:
    raise ValueError("{} is unknown algorithm".format(args.alg))

print()
iteration = 0
maximum_val_acc = 0
s = time.time()
for l_data, u_data in zip(l_loader, u_loader):
    iteration += 1
    l_input, target = l_data
    l_input, target = l_input.to(device).float(), target.to(device).long()

    if args.alg != "supervised": # for ssl algorithm
        u_input, dummy_target = u_data
        u_input, dummy_target = u_input.to(device).float(), dummy_target.to(device).long()

        target = torch.cat([target, dummy_target], 0)
        unlabeled_mask = (target == -1).float()

        inputs = torch.cat([l_input, u_input], 0)
        outputs = model(inputs)

        # ramp up exp(-5(1 - t)^2)
        coef = alg_cfg["consis_coef"] * math.exp(-5 * (1 - min(iteration/shared_cfg["warmup"], 1))**2)
        ssl_loss = ssl_obj(inputs, outputs.detach(), model, unlabeled_mask) * coef

    else:
        outputs = model(l_input)
        coef = 0
        ssl_loss = torch.zeros(1).to(device)

    # supervised loss
    cls_loss = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean()

    loss = cls_loss + ssl_loss

    if args.em > 0:
        loss -= args.em * ((outputs.softmax(1) * F.log_softmax(outputs, 1)).sum(1) * unlabeled_mask).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if args.alg == "MT" or args.alg == "ICT":
        # parameter update with exponential moving average
        ssl_obj.moving_average(model.parameters())
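    # display (NOTE: the iter/sec and rest figures below assume a 100-iteration logging interval;
    # with logging every 1000 iterations, iter/sec is under-reported and rest over-reported by a factor of 10)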
    if iteration == 1 or (iteration % 1000) == 0:
        wasted_time = time.time() - s
        rest = (shared_cfg["iteration"] - iteration)/100 * wasted_time / 60
        print("iteration [{}/{}] cls loss : {:.6e}, SSL loss : {:.6e}, coef : {:.5e}, time : {:.3f} iter/sec, rest : {:.3f} min, lr : {}".format(
            iteration, shared_cfg["iteration"], cls_loss.item(), ssl_loss.item(), coef, 100 / wasted_time, rest, optimizer.param_groups[0]["lr"]),
            "\r")
        s = time.time()

    # validation
    if (iteration % args.validation) == 0 or iteration == shared_cfg["iteration"]:
        with torch.no_grad():
            model.eval()
            print()
            print("### validation ###")
            sum_acc = 0.
            s = time.time()
            for j, data in enumerate(val_loader):
                input, target = data
                input, target = input.to(device).float(), target.to(device).long()

                output = model(input)

                pred_label = output.max(1)[1]
                sum_acc += (pred_label == target).float().sum()
                if ((j+1) % 10) == 0:
                    d_p_s = 10/(time.time()-s)
                    print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                        j+1, len(val_loader), d_p_s, (len(val_loader) - j-1)/d_p_s
                    ), "\r", end="")
                    s = time.time()
            acc = sum_acc/float(len(val_dataset))
            print()
            print("validation accuracy : {}".format(acc))
            # test
            if maximum_val_acc < acc:
                print("### test ###")
                maximum_val_acc = acc
                sum_acc = 0.
                s = time.time()
                for j, data in enumerate(test_loader):
                    input, target = data
                    input, target = input.to(device).float(), target.to(device).long()
                    output = model(input)
                    pred_label = output.max(1)[1]
                    sum_acc += (pred_label == target).float().sum()
                    if ((j+1) % 10) == 0:
                        d_p_s = 100/(time.time()-s)
                        print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                            j+1, len(test_loader), d_p_s, (len(test_loader) - j-1)/d_p_s
                        ), "\r", end="")
                        s = time.time()
                print()
                test_acc = sum_acc / float(len(test_dataset))
                print("test accuracy : {}".format(test_acc))
                # torch.save(model.state_dict(), os.path.join(args.output, "best_model.pth"))
        model.train()
        s = time.time()
    # lr decay
    if iteration == shared_cfg["lr_decay_iter"]:
        optimizer.param_groups[0]["lr"] *= shared_cfg["lr_decay_factor"]

print("test acc : {}".format(test_acc))
condition["test_acc"] = test_acc.item()

exp_name += str(int(time.time())) # unique ID
if not os.path.exists(args.output):
    os.mkdir(args.output)
with open(os.path.join(args.output, exp_name + ".json"), "w") as f:
    json.dump(condition, f)


iteration [1/8800] cls loss : 1.193980e+00, SSL loss : 1.643282e-05, coef : 1.36113e-01, time : 73.715 iter/sec, rest : 1.989 min, lr : 0.0003 
iteration [1000/8800] cls loss : 6.385348e-02, SSL loss : 1.389795e-01, coef : 2.00000e+01, time : 0.890 iter/sec, rest : 146.143 min, lr : 5.9999999999999995e-05 
iteration [2000/8800] cls loss : 3.971205e-02, SSL loss : 1.439586e-01, coef : 2.00000e+01, time : 0.885 iter/sec, rest : 127.999 min, lr : 5.9999999999999995e-05 
iteration [3000/8800] cls loss : 2.138251e-02, SSL loss : 1.379075e-01, coef : 2.00000e+01, time : 0.886 iter/sec, rest : 109.090 min, lr : 5.9999999999999995e-05 
iteration [4000/8800] cls loss : 2.173064e-02, SSL loss : 1.026384e-01, coef : 2.00000e+01, time : 0.885 iter/sec, rest : 90.377 min, lr : 5.9999999999999995e-05 
iteration [5000/8800] cls loss : 1.885763e-02, SSL loss : 5.245686e-02, coef : 2.00000e+01, time : 0.884 iter/sec, rest : 71.618 min, lr : 5.9999999999999995e-05 
iteration [6000/8800] cls loss : 2.21
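The `coef` column in the log above comes from the ramp-up expression in the training loop, `consis_coef * exp(-5 * (1 - t)^2)` with `t = min(iteration / warmup, 1)`. The standalone sketch below reproduces it with the values from this run (`consis_coef=20.0`, `warmup=1000`); the function name `consistency_coef` is an assumption.

```python
import math


def consistency_coef(iteration, consis_coef, warmup):
    """Ramp-up weight: consis_coef * exp(-5 * (1 - t)^2), t = min(iteration / warmup, 1)."""
    t = min(iteration / warmup, 1.0)
    return consis_coef * math.exp(-5.0 * (1.0 - t) ** 2)


for it in (1, 250, 500, 1000, 5000):
    print(it, f"{consistency_coef(it, consis_coef=20.0, warmup=1000):.5e}")
# iteration 1 gives 1.36113e-01 and iterations >= 1000 give 2.00000e+01,
# matching the coef values in the log.
```

The unsupervised term therefore contributes almost nothing at the start, when the model's predictions are still unreliable, and reaches full weight once warmup ends.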

In [None]:
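* Interpretation of results:
    The Pi-Model was trained on the CIFAR-10 training split (50000 images, of which 1000 labeled and 44000 unlabeled are used for training and 5000 for validation) for 8800 iterations with a batch size of 256.
    Training converged to "supervised loss : 0.02", "unsupervised loss: 0.08"
    
    The test accuracy is 0.5914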

### Mean Teacher(MT)

In [13]:
args.alg = 'MT'

In [14]:
if args.alg != "supervised":
    # batch size = 0.5 x batch size
    l_loader = DataLoader(
        l_train_dataset, shared_cfg["batch_size"]//2, drop_last=True,
        sampler=RandomSampler(len(l_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"]//2)
    )
else:
    l_loader = DataLoader(
        l_train_dataset, shared_cfg["batch_size"], drop_last=True,
        sampler=RandomSampler(len(l_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"])
    )
print("algorithm : {}".format(args.alg))
condition["algorithm"] = args.alg
exp_name += str(args.alg) + "_"

u_loader = DataLoader(
    u_train_dataset, shared_cfg["batch_size"]//2, drop_last=True,
    sampler=RandomSampler(len(u_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"]//2)
)

val_loader = DataLoader(val_dataset, 128, shuffle=False, drop_last=False)
test_loader = DataLoader(test_dataset, 128, shuffle=False, drop_last=False)

print("maximum iteration : {}".format(min(len(l_loader), len(u_loader))))

alg_cfg = config[args.alg]
print("parameters : ", alg_cfg)
condition["h_parameters"] = alg_cfg

if args.em > 0:
    print("entropy minimization : {}".format(args.em))
    exp_name += "em_"
condition["entropy_maximization"] = args.em

model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
optimizer = optim.Adam(model.parameters(), lr=alg_cfg["lr"])

trainable_paramters = sum([p.data.nelement() for p in model.parameters()])
print("trainable parameters : {}".format(trainable_paramters))

algorithm : MT
maximum iteration : 8800
parameters :  {'ema_factor': 0.95, 'lr': 0.0004, 'consis_coef': 8}
trainable parameters : 1467610


In [15]:
if args.alg == "VAT": # virtual adversarial training
    from vat import VAT
    ssl_obj = VAT(alg_cfg["eps"][args.dataset], alg_cfg["xi"], 1)
elif args.alg == "MT": # mean teacher
    from mean_teacher import MT
    t_model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
    t_model.load_state_dict(model.state_dict())
    ssl_obj = MT(t_model, alg_cfg["ema_factor"])
elif args.alg == "PI": # PI Model
    from pimodel import PiModel
    ssl_obj = PiModel()
else:
    raise ValueError("{} is unknown algorithm".format(args.alg))

print()
iteration = 0
maximum_val_acc = 0
s = time.time()
for l_data, u_data in zip(l_loader, u_loader):
    iteration += 1
    l_input, target = l_data
    l_input, target = l_input.to(device).float(), target.to(device).long()

    if args.alg != "supervised": # for ssl algorithm
        u_input, dummy_target = u_data
        u_input, dummy_target = u_input.to(device).float(), dummy_target.to(device).long()

        target = torch.cat([target, dummy_target], 0)
        unlabeled_mask = (target == -1).float()

        inputs = torch.cat([l_input, u_input], 0)
        outputs = model(inputs)

        # ramp up exp(-5(1 - t)^2)
        coef = alg_cfg["consis_coef"] * math.exp(-5 * (1 - min(iteration/shared_cfg["warmup"], 1))**2)
        ssl_loss = ssl_obj(inputs, outputs.detach(), model, unlabeled_mask) * coef

    else:
        outputs = model(l_input)
        coef = 0
        ssl_loss = torch.zeros(1).to(device)

    # supervised loss
    cls_loss = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean()

    loss = cls_loss + ssl_loss

    if args.em > 0:
        loss -= args.em * ((outputs.softmax(1) * F.log_softmax(outputs, 1)).sum(1) * unlabeled_mask).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if args.alg == "MT" or args.alg == "ICT":
        # parameter update with exponential moving average
        ssl_obj.moving_average(model.parameters())
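    # display (NOTE: the iter/sec and rest figures below assume a 100-iteration logging interval;
    # with logging every 1000 iterations, iter/sec is under-reported and rest over-reported by a factor of 10)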
    if iteration == 1 or (iteration % 1000) == 0:
        wasted_time = time.time() - s
        rest = (shared_cfg["iteration"] - iteration)/100 * wasted_time / 60
        print("iteration [{}/{}] cls loss : {:.6e}, SSL loss : {:.6e}, coef : {:.5e}, time : {:.3f} iter/sec, rest : {:.3f} min, lr : {}".format(
            iteration, shared_cfg["iteration"], cls_loss.item(), ssl_loss.item(), coef, 100 / wasted_time, rest, optimizer.param_groups[0]["lr"]),
            "\r")
        s = time.time()

    # validation
    if (iteration % args.validation) == 0 or iteration == shared_cfg["iteration"]:
        with torch.no_grad():
            model.eval()
            print()
            print("### validation ###")
            sum_acc = 0.
            s = time.time()
            for j, data in enumerate(val_loader):
                input, target = data
                input, target = input.to(device).float(), target.to(device).long()

                output = model(input)

                pred_label = output.max(1)[1]
                sum_acc += (pred_label == target).float().sum()
                if ((j+1) % 10) == 0:
                    d_p_s = 10/(time.time()-s)
                    print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                        j+1, len(val_loader), d_p_s, (len(val_loader) - j-1)/d_p_s
                    ), "\r", end="")
                    s = time.time()
            acc = sum_acc/float(len(val_dataset))
            print()
            print("validation accuracy : {}".format(acc))
            # test
            if maximum_val_acc < acc:
                print("### test ###")
                maximum_val_acc = acc
                sum_acc = 0.
                s = time.time()
                for j, data in enumerate(test_loader):
                    input, target = data
                    input, target = input.to(device).float(), target.to(device).long()
                    output = model(input)
                    pred_label = output.max(1)[1]
                    sum_acc += (pred_label == target).float().sum()
                    if ((j+1) % 10) == 0:
                        d_p_s = 100/(time.time()-s)
                        print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                            j+1, len(test_loader), d_p_s, (len(test_loader) - j-1)/d_p_s
                        ), "\r", end="")
                        s = time.time()
                print()
                test_acc = sum_acc / float(len(test_dataset))
                print("test accuracy : {}".format(test_acc))
                # torch.save(model.state_dict(), os.path.join(args.output, "best_model.pth"))
        model.train()
        s = time.time()
    # lr decay
    if iteration == shared_cfg["lr_decay_iter"]:
        optimizer.param_groups[0]["lr"] *= shared_cfg["lr_decay_factor"]

print("test acc : {}".format(test_acc))
condition["test_acc"] = test_acc.item()

exp_name += str(int(time.time())) # unique ID
if not os.path.exists(args.output):
    os.mkdir(args.output)
with open(os.path.join(args.output, exp_name + ".json"), "w") as f:
    json.dump(condition, f)


iteration [1/8800] cls loss : 1.194580e+00, SSL loss : 6.223509e-06, coef : 5.44450e-02, time : 724.812 iter/sec, rest : 0.202 min, lr : 0.0004 
iteration [1000/8800] cls loss : 4.593794e-02, SSL loss : 5.544227e-02, coef : 8.00000e+00, time : 0.759 iter/sec, rest : 171.335 min, lr : 8e-05 
iteration [2000/8800] cls loss : 1.836864e-02, SSL loss : 4.573943e-02, coef : 8.00000e+00, time : 0.758 iter/sec, rest : 149.500 min, lr : 8e-05 
iteration [3000/8800] cls loss : 1.258721e-02, SSL loss : 5.397944e-02, coef : 8.00000e+00, time : 0.758 iter/sec, rest : 127.467 min, lr : 8e-05 
iteration [4000/8800] cls loss : 9.231715e-03, SSL loss : 5.635411e-02, coef : 8.00000e+00, time : 0.758 iter/sec, rest : 105.551 min, lr : 8e-05 
iteration [5000/8800] cls loss : 6.741161e-03, SSL loss : 5.954162e-02, coef : 8.00000e+00, time : 0.758 iter/sec, rest : 83.559 min, lr : 8e-05 
iteration [6000/8800] cls loss : 5.356046e-03, SSL loss : 2.431375e-02, coef : 8.00000e+00, time : 0.760 iter/sec, rest 

In [None]:
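* Interpretation of results:
    The Mean Teacher model was trained on the CIFAR-10 training split (50000 images, of which 1000 labeled and 44000 unlabeled are used for training and 5000 for validation) for 8800 iterations with a batch size of 256.
    Training converged to "supervised loss : 0.005", "unsupervised loss: 0.02"
    
    The test accuracy is 0.5929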

### Virtual Adversarial Training(VAT)

In [10]:
args.alg = 'VAT'

In [11]:
if args.alg != "supervised":
    # batch size = 0.5 x batch size
    l_loader = DataLoader(
        l_train_dataset, shared_cfg["batch_size"]//2, drop_last=True,
        sampler=RandomSampler(len(l_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"]//2)
    )
else:
    l_loader = DataLoader(
        l_train_dataset, shared_cfg["batch_size"], drop_last=True,
        sampler=RandomSampler(len(l_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"])
    )
print("algorithm : {}".format(args.alg))
condition["algorithm"] = args.alg
exp_name += str(args.alg) + "_"

u_loader = DataLoader(
    u_train_dataset, shared_cfg["batch_size"]//2, drop_last=True,
    sampler=RandomSampler(len(u_train_dataset), shared_cfg["iteration"] * shared_cfg["batch_size"]//2)
)

val_loader = DataLoader(val_dataset, 128, shuffle=False, drop_last=False)
test_loader = DataLoader(test_dataset, 128, shuffle=False, drop_last=False)

print("maximum iteration : {}".format(min(len(l_loader), len(u_loader))))

alg_cfg = config[args.alg]
print("parameters : ", alg_cfg)
condition["h_parameters"] = alg_cfg

if args.em > 0:
    print("entropy minimization : {}".format(args.em))
    exp_name += "em_"
condition["entropy_maximization"] = args.em

model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
optimizer = optim.Adam(model.parameters(), lr=alg_cfg["lr"])

trainable_paramters = sum([p.data.nelement() for p in model.parameters()])
print("trainable parameters : {}".format(trainable_paramters))

algorithm : VAT
maximum iteration : 8800
parameters :  {'xi': 1e-06, 'eps': {'cifar10': 6, 'svhn': 1}, 'consis_coef': 0.3, 'lr': 0.003}
trainable parameters : 1467610


In [12]:
if args.alg == "VAT": # virtual adversarial training
    from vat import VAT
    ssl_obj = VAT(alg_cfg["eps"][args.dataset], alg_cfg["xi"], 1)
elif args.alg == "MT": # mean teacher
    from mean_teacher import MT
    t_model = wrn.WRN(2, dataset_cfg["num_classes"], transform_fn).to(device)
    t_model.load_state_dict(model.state_dict())
    ssl_obj = MT(t_model, alg_cfg["ema_factor"])
elif args.alg == "PI": # PI Model
    from pimodel import PiModel
    ssl_obj = PiModel()
else:
    raise ValueError("{} is unknown algorithm".format(args.alg))

print()
iteration = 0
maximum_val_acc = 0
s = time.time()
for l_data, u_data in zip(l_loader, u_loader):
    iteration += 1
    l_input, target = l_data
    l_input, target = l_input.to(device).float(), target.to(device).long()

    if args.alg != "supervised": # for ssl algorithm
        u_input, dummy_target = u_data
        u_input, dummy_target = u_input.to(device).float(), dummy_target.to(device).long()

        target = torch.cat([target, dummy_target], 0)
        unlabeled_mask = (target == -1).float()

        inputs = torch.cat([l_input, u_input], 0)
        outputs = model(inputs)

        # ramp up exp(-5(1 - t)^2)
        coef = alg_cfg["consis_coef"] * math.exp(-5 * (1 - min(iteration/shared_cfg["warmup"], 1))**2)
        ssl_loss = ssl_obj(inputs, outputs.detach(), model, unlabeled_mask) * coef

    else:
        outputs = model(l_input)
        coef = 0
        ssl_loss = torch.zeros(1).to(device)

    # supervised loss
    cls_loss = F.cross_entropy(outputs, target, reduction="none", ignore_index=-1).mean()

    loss = cls_loss + ssl_loss

    if args.em > 0:
        loss -= args.em * ((outputs.softmax(1) * F.log_softmax(outputs, 1)).sum(1) * unlabeled_mask).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if args.alg == "MT" or args.alg == "ICT":
        # parameter update with exponential moving average
        ssl_obj.moving_average(model.parameters())
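    # display (NOTE: the iter/sec and rest figures below assume a 100-iteration logging interval;
    # with logging every 1000 iterations, iter/sec is under-reported and rest over-reported by a factor of 10)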
    if iteration == 1 or (iteration % 1000) == 0:
        wasted_time = time.time() - s
        rest = (shared_cfg["iteration"] - iteration)/100 * wasted_time / 60
        print("iteration [{}/{}] cls loss : {:.6e}, SSL loss : {:.6e}, coef : {:.5e}, time : {:.3f} iter/sec, rest : {:.3f} min, lr : {}".format(
            iteration, shared_cfg["iteration"], cls_loss.item(), ssl_loss.item(), coef, 100 / wasted_time, rest, optimizer.param_groups[0]["lr"]),
            "\r")
        s = time.time()

    # validation
    if (iteration % args.validation) == 0 or iteration == shared_cfg["iteration"]:
        with torch.no_grad():
            model.eval()
            print()
            print("### validation ###")
            sum_acc = 0.
            s = time.time()
            for j, data in enumerate(val_loader):
                input, target = data
                input, target = input.to(device).float(), target.to(device).long()

                output = model(input)

                pred_label = output.max(1)[1]
                sum_acc += (pred_label == target).float().sum()
                if ((j+1) % 10) == 0:
                    d_p_s = 10/(time.time()-s)
                    print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                        j+1, len(val_loader), d_p_s, (len(val_loader) - j-1)/d_p_s
                    ), "\r", end="")
                    s = time.time()
            acc = sum_acc/float(len(val_dataset))
            print()
            print("validation accuracy : {}".format(acc))
            # test
            if maximum_val_acc < acc:
                print("### test ###")
                maximum_val_acc = acc
                sum_acc = 0.
                s = time.time()
                for j, data in enumerate(test_loader):
                    input, target = data
                    input, target = input.to(device).float(), target.to(device).long()
                    output = model(input)
                    pred_label = output.max(1)[1]
                    sum_acc += (pred_label == target).float().sum()
                    if ((j+1) % 10) == 0:
                        d_p_s = 100/(time.time()-s)
                        print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                            j+1, len(test_loader), d_p_s, (len(test_loader) - j-1)/d_p_s
                        ), "\r", end="")
                        s = time.time()
                print()
                test_acc = sum_acc / float(len(test_dataset))
                print("test accuracy : {}".format(test_acc))
                # torch.save(model.state_dict(), os.path.join(args.output, "best_model.pth"))
        model.train()
        s = time.time()
    # lr decay
    if iteration == shared_cfg["lr_decay_iter"]:
        optimizer.param_groups[0]["lr"] *= shared_cfg["lr_decay_factor"]

print("test acc : {}".format(test_acc))
condition["test_acc"] = test_acc.item()

exp_name += str(int(time.time())) # unique ID
if not os.path.exists(args.output):
    os.mkdir(args.output)
with open(os.path.join(args.output, exp_name + ".json"), "w") as f:
    json.dump(condition, f)


iteration [1/8800] cls loss : 1.210762e+00, SSL loss : 8.447826e-06, coef : 2.04169e-03, time : 2291.143 iter/sec, rest : 0.064 min, lr : 0.003 
iteration [1000/8800] cls loss : 4.877531e-02, SSL loss : 3.387230e-01, coef : 3.00000e-01, time : 0.596 iter/sec, rest : 218.078 min, lr : 0.0006000000000000001 
iteration [2000/8800] cls loss : 1.555471e-02, SSL loss : 2.539856e-01, coef : 3.00000e-01, time : 0.596 iter/sec, rest : 190.102 min, lr : 0.0006000000000000001 
iteration [3000/8800] cls loss : 2.724327e-02, SSL loss : 2.493916e-01, coef : 3.00000e-01, time : 0.596 iter/sec, rest : 162.275 min, lr : 0.0006000000000000001 
iteration [4000/8800] cls loss : 1.993705e-02, SSL loss : 2.151401e-01, coef : 3.00000e-01, time : 0.595 iter/sec, rest : 134.360 min, lr : 0.0006000000000000001 
iteration [5000/8800] cls loss : 7.853318e-03, SSL loss : 2.104976e-01, coef : 3.00000e-01, time : 0.596 iter/sec, rest : 106.228 min, lr : 0.0006000000000000001 
iteration [6000/8800] cls loss : 1.1173

In [None]:
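* Interpretation of results:
    Training the VAT model on the 50,000 CIFAR-10 training images with a batch size of 256 for 8,800 iterations,
    the losses settled at roughly "supervised loss: 0.005" and "unsupervised loss: 0.2".
    
    The test accuracy is 0.65.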
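
In [None]:
* Reading the timing fields in the log (a sketch, not part of the original script):
    The training cell prints `100 / wasted_time` as "iter/sec" but only logs every 1000 iterations.
    Assuming no validation pass reset the timer between two log lines, the elapsed time covers about 1000 iterations,
    so the real throughput would be about 10x the logged value and the remaining time about 10x shorter.
    The ranking between models is unaffected as long as all of them were logged the same way.
    The helper `throughput` below is hypothetical and only redoes this arithmetic.

```python
# Hypothetical helper (not in the original script): recompute throughput and
# remaining time from the seconds elapsed between two log lines.
def throughput(elapsed_sec, iters_between_logs, total_iters, current_iter):
    iters_per_sec = iters_between_logs / elapsed_sec
    remaining_min = (total_iters - current_iter) / iters_per_sec / 60
    return iters_per_sec, remaining_min

# The log line at iteration 1000 reports 0.596 "iter/sec", i.e. 100 / elapsed.
elapsed = 100 / 0.596  # seconds between the two log lines

as_logged = throughput(elapsed, 100, 8800, 1000)      # divisor used by the script
per_interval = throughput(elapsed, 1000, 8800, 1000)  # assumption: 1000 iterations elapsed

print("as logged:    {:.3f} iter/sec, {:.1f} min left".format(*as_logged))
print("per interval: {:.3f} iter/sec, {:.1f} min left".format(*per_interval))
```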

In [None]:
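* Overall interpretation:
    On CIFAR-10, with the same hyperparameters for every model, the test accuracies rank as follows:
    VAT, Mean Teacher, then the Pi-Model.
    
    1. VAT: 65.07% (0.597 iter/sec)
    2. Mean Teacher: 59.29% (0.759 iter/sec)
    3. Pi-Model: 59.14% (0.886 iter/sec)
        
    All three models share the same network with 1,467,610 trainable parameters, so the gap comes from the
    SSL method itself: VAT, which is trained to be robust to small adversarial perturbations, reaches the highest accuracy.
    
    Mean Teacher, which separates the teacher from the student, is slightly more accurate than the Pi-Model
    (59.29% vs. 59.14%), but the gain is far smaller than VAT's and its logged throughput is lower (0.759 vs. 0.886 iter/sec).
    Among the consistency-regularization approaches compared here, VAT gives the best image classification
    accuracy, although it is also the slowest per iteration (0.597 iter/sec).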

## Comparison on the MNIST dataset

## 1. Temporal Ensemble (2016)

In [None]:
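* Temporal Ensemble (2016): the limitation of the Pi-Model is that it relies on a 'single network', so
    Temporal Ensembling aggregates the predictions of multiple previous network evaluations into an ensemble prediction.
    
    Because the teacher output (the target) is noisy, it is accumulated with an EMA (exponential moving average) to make it more stable.
    
    (Drawback) Memory is needed to keep Z across epochs <-- the accumulated output vectors are stored in Z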
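
In [None]:
* Sketch of the ensemble-target update for a single sample (illustrative only):
    The accumulation described above is an EMA of the per-epoch outputs followed by a bias correction,
    because Z starts at zero. The values of `epoch_outputs` are made up, and probabilities are used
    instead of logits to keep the numbers easy to read.

```python
# Minimal sketch of the temporal-ensembling target update for ONE sample.
# alpha and the per-epoch outputs are illustrative values, not results.
alpha = 0.6
n_classes = 3

Z = [0.0] * n_classes          # accumulated (biased) ensemble output
epoch_outputs = [
    [0.2, 0.5, 0.3],           # network output at epoch 1
    [0.1, 0.7, 0.2],           # network output at epoch 2
    [0.1, 0.8, 0.1],           # network output at epoch 3
]

targets = []
for t, out in enumerate(epoch_outputs, start=1):
    # exponential moving average of the outputs
    Z = [alpha * z_prev + (1.0 - alpha) * o for z_prev, o in zip(Z, out)]
    # bias correction: Z starts at zero, so divide by (1 - alpha^t)
    z_hat = [z / (1.0 - alpha ** t) for z in Z]
    targets.append(z_hat)
    print("epoch", t, [round(v, 4) for v in z_hat])
```

    After the first epoch the corrected target equals the first output, which is the purpose of the correction.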

In [None]:
* Source: https://github.com/ferretj/temporal-ensembling

In [18]:
import torch
import torch.utils as utils
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets

import torchvision
from torchvision import transforms

import numpy as np
import time
import os, sys

In [19]:
def mnist_dataset(root, transform):

    # load train data
    train_dataset = datasets.MNIST(
        root=root,
        train=True,
        transform=transform,
        download=True)

    # load test data
    test_dataset = datasets.MNIST(
        root=root,
        train=False,
        transform=transform, download=True)

    return train_dataset, test_dataset

def sample_train(train_dataset, test_dataset, batch_size, k, n_classes, seed, shuffle_train=False, return_idx=True):
    '''Randomly form unlabeled data in training dataset'''

    n = len(train_dataset)  # dataset size
    rrng = np.random.RandomState(seed) # seed 
    indices = torch.zeros(k)  # indices of keep labeled data
    others = torch.zeros(n - k)  # indices of unlabeled data
    card = k // n_classes
    cpt = 0

    for i in range(n_classes):
        class_items = (train_dataset.targets == i).nonzero()  # indices of samples with label i (`train_labels` is deprecated)
        n_class = len(class_items)  # number of samples with label i
        rd = rrng.permutation(np.arange(n_class))  # shuffle them
        indices[i * card: (i+1) * card] = torch.squeeze(class_items[rd[:card]])
        others[cpt: cpt+n_class-card] = torch.squeeze(class_items[rd[card:]])
        cpt += (n_class-card)

    # tensor as indices must be long, byte or bool
    others = others.long()
    train_dataset.targets[others] = -1  # mark unlabeled samples with -1 (`train_labels` is deprecated)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               num_workers=2,
                                               shuffle=shuffle_train)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              num_workers=2,
                                              shuffle=False)

    if return_idx:
        return train_loader, test_loader, indices
    return train_loader, test_loader
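
In [None]:
* Sketch of the label masking done by `sample_train` (pure Python, toy data):
    `sample_train` keeps `k // n_classes` labels per class and overwrites the remaining labels with -1.
    The function `keep_k_labels` below is a hypothetical stand-in that mirrors this idea on a plain list;
    it is not the notebook's implementation.

```python
import random

def keep_k_labels(labels, k, n_classes, seed):
    """Return a copy of `labels` where all but k // n_classes per class are -1."""
    rng = random.Random(seed)
    per_class = k // n_classes
    masked = [-1] * len(labels)
    kept = []
    for c in range(n_classes):
        idx = [i for i, y in enumerate(labels) if y == c]
        rng.shuffle(idx)                  # random choice within the class
        for i in idx[:per_class]:
            masked[i] = labels[i]         # keep the true label
            kept.append(i)
    return masked, kept

# toy data: 1000 samples, 10 classes, keep 100 labels (10 per class)
toy_labels = [i % 10 for i in range(1000)]
masked, kept = keep_k_labels(toy_labels, k=100, n_classes=10, seed=0)
print(len(kept), masked.count(-1))
```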

In [20]:
class GaussianNoise(nn.Module):
    """Add Gaussian noise to the input."""
    def __init__(self, batch_size, input_shape, std):
        super(GaussianNoise, self).__init__()
        self.shape = (batch_size, ) + input_shape
        self.std = std
        self.noise = torch.zeros(self.shape).cuda()

    def forward(self, x):
        self.noise.normal_(mean=0, std=self.std)
        # print(self.noise.shape)

        return x + self.noise

    @staticmethod
    def temporal_losses(out1, out2, w, labels):
        """Compute the supervised, unsupervised and total loss from the
        current output and the ensemble output.

        out1: current output
        out2: temporal (ensemble) output
        w: weight of the unsupervised loss in the total loss
        """

        sup_loss, nbsup = GaussianNoise.masked_crossentropy(out1, labels)
        unsup_loss = GaussianNoise.mse_loss(out1, out2)
        total_loss = sup_loss + w * unsup_loss

        return total_loss, sup_loss, unsup_loss, nbsup

    @staticmethod
    def mse_loss(out1, out2):
        "Unsupervised loss: mean squared difference between the softmax of the current output and of the ensemble output"
        quad_diff = torch.sum((F.softmax(out1, dim=1) - F.softmax(out2, dim=1)) ** 2)

        return quad_diff / out1.data.nelement()

    @staticmethod
    def masked_crossentropy(out, labels):
        "Compute the cross-entropy loss on labeled samples only (label >= 0)"
        cond = (labels >= 0)
        nnz = torch.nonzero(cond)  # array of labeled sample index
        nbsup = len(nnz)  # number of supervised samples
        # check if labeled samples in batch, return 0 if none
        if nbsup > 0:
            # select lines in out with label
            masked_outputs = torch.index_select(out, 0, nnz.view(nbsup))
            masked_labels = labels[cond]
            loss = F.cross_entropy(masked_outputs, masked_labels)
            return loss, nbsup
        loss = torch.tensor([0.], requires_grad=False).cuda()
        return loss, 0

    @staticmethod
    def weight_scheduler(epoch, max_epochs, max_val, mult, n_labeled, n_samples):
        "Scale the maximum unsupervised weight by the labeled fraction and ramp it up over the epochs"
        max_val = max_val * (float(n_labeled) / n_samples)
        return GaussianNoise.ramp_up(epoch, max_epochs, max_val, mult)

    @staticmethod
    def ramp_up(epoch, max_epochs, max_val, mult):
        "Ramp the weight up over the epochs; the weight is 0 in the first epoch"
        if epoch == 0:
            return 0.
        elif epoch >= max_epochs:
            return max_val
        return max_val * np.exp(-mult * (1. - float(epoch) / max_epochs) ** 2)
    
    @staticmethod
    def calc_metrics(model, loader):
        "Return the accuracy (in percent) of the model on the given loader"
        correct = 0
        total = 0
        for i, (samples, labels) in enumerate(loader):
            samples = samples.cuda()
            labels = labels.requires_grad_(False).cuda()
            outputs = model(samples)
            _, predicted = torch.max(outputs.detach(), 1)
            total += labels.size(0)
            correct += (predicted == labels.detach().view_as(predicted)).sum()
        acc = 100 * float(correct) / total
        return acc
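
In [None]:
* Worked example of the unsupervised weight schedule (a sketch using the values from this notebook):
    `weight_scheduler` scales `max_val` by the labeled fraction (100 / 60000) and `ramp_up` applies a Gaussian-shaped ramp.
    The function below repeats the same formula with `math.exp` so it runs without NumPy or PyTorch.
    With `max_epochs=80` (the default of `train`) and 50 training epochs, the weight appears to reach only
    about half of its maximum by the last epoch.

```python
import math

def ramp_up(epoch, max_epochs, max_val, mult):
    # same formula as the notebook's ramp_up, rewritten with math.exp
    if epoch == 0:
        return 0.0
    if epoch >= max_epochs:
        return max_val
    return max_val * math.exp(-mult * (1.0 - epoch / max_epochs) ** 2)

n_labeled, n_samples = 100, 60000
max_val = 1.0 * n_labeled / n_samples   # scaled as in weight_scheduler

# epoch 49 is the last epoch of a 50-epoch run; epoch 80 is where the ramp would end
for epoch in (0, 1, 10, 25, 49, 80):
    w = ramp_up(epoch, max_epochs=80, max_val=max_val, mult=5)
    print("epoch {:2d}: w = {:.6f} ({:.1%} of max)".format(epoch, w, w / max_val))
```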

In [21]:
class CNN(nn.Module):
    def __init__(self, batch_size, std, input_shape=(1, 28, 28), p=0.5, fm1=16, fm2=32):
        super(CNN, self).__init__()
        self.std = std
        self.p = p
        self.fm1 = fm1
        self.fm2 = fm2
        self.input_shape = input_shape
        self.conv_block1 = nn.Sequential(nn.Conv2d(1, self.fm1, 3, stride=1, padding=1),
                                        nn.BatchNorm2d(self.fm1), 
                                        nn.ReLU(),
                                        nn.MaxPool2d(3, stride=2, padding=1)
                                      )
        
        self.conv_block2 = nn.Sequential(nn.Conv2d(self.fm1, self.fm2, 3, stride=1, padding=1),
                                        nn.BatchNorm2d(self.fm2), 
                                        nn.ReLU(),
                                        nn.MaxPool2d(3, stride=2, padding=1)
                                      )
        self.drop = nn.Dropout(self.p)
        self.fc = nn.Linear(self.fm2 * 7 * 7, 10)


    def forward(self, x):
        if self.training:
            b = x.size(0)
            gn = GaussianNoise(b, self.input_shape, self.std)
            x = gn(x)

        # first block
        x = self.conv_block1(x)
        
        # second block
        x = self.conv_block2(x)

        # classifier
        x = x.view(-1, self.fm2 * 7 * 7)
        x = self.fc(self.drop(x))

        return x
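
In [None]:
* Why the classifier takes `fm2 * 7 * 7` inputs (shape arithmetic, a sketch):
    Each block applies a 3x3 convolution with padding 1 (size unchanged) and a 3x3 max-pool with stride 2 and padding 1
    (size halved), so a 28x28 MNIST image becomes 14x14 and then 7x7. The helper `out_size` is hypothetical
    and uses the standard floor formula for convolution and pooling output sizes.

```python
def out_size(size, kernel, stride, padding):
    # output length of a conv / pooling layer along one spatial dimension
    return (size + 2 * padding - kernel) // stride + 1

size = 28
for block in (1, 2):
    size = out_size(size, kernel=3, stride=1, padding=1)   # Conv2d(..., 3, stride=1, padding=1)
    size = out_size(size, kernel=3, stride=2, padding=1)   # MaxPool2d(3, stride=2, padding=1)
    print("after block", block, "->", size, "x", size)

fm2 = 64  # value set in the configuration cell
print("fc input features:", fm2 * size * size)
```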

In [41]:
## Training

def train(model, train_loader, val_loader ,seed, k, alpha, lr, num_epochs, batch_size, ntrain,n_classes=10, max_epochs=80, max_val=1.):

    # build model and feed to GPU
    model.cuda()

    # setup param optimization
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99))

    # model.train()
    
    # the initial ensemble outputs are all zeros
    Z = torch.zeros(ntrain, n_classes).float().cuda()  # intermediate values
    z = torch.zeros(ntrain, n_classes).float().cuda()  # temporal outputs
    outputs = torch.zeros(ntrain, n_classes).float().cuda()  # current outputs

    losses = []
    suplosses = []
    unsuplosses = []
    best_loss = 30.0
    for epoch in range(num_epochs):
        start_t = time.time()
        if epoch==0 or epoch%10==0 or epoch==49:
            print('\nEpoch: {}'.format(epoch+1))
        model.train()
        # evaluate unsupervised cost weight
        w = GaussianNoise.weight_scheduler(epoch, max_epochs, max_val, 5, k, 60000)

        w = torch.tensor(w, requires_grad=False).cuda()
        if epoch==0 or epoch%10==0 or epoch==49:
            print('---------------------')

        # targets change only once per epoch
        for i, (images, labels) in enumerate(train_loader):
            # the last batch may be smaller (drop_last is False), so compute the slice
            # from the loader's nominal batch size rather than the current batch size
            start = i * train_loader.batch_size
            end = start + images.size(0)
            images = images.cuda()
            labels = labels.requires_grad_(False).cuda()

            optimizer.zero_grad()
            out = model(images)
            # fetch the ensemble targets that correspond to the current batch
            zcomp = z[start:end]
            zcomp.requires_grad_(False)
            loss, suploss, unsuploss, nbsup = GaussianNoise.temporal_losses(out, zcomp, w, labels)

            # save outputs
            outputs[start:end] = out.clone().detach()
            losses.append(loss.item())
            suplosses.append(nbsup * suploss.item())
            unsuplosses.append(unsuploss.item())

            # backprop
            loss.backward()
            optimizer.step()

        loss_mean = np.mean(losses)
        supl_mean = np.mean(suplosses)
        unsupl_mean = np.mean(unsuplosses)
        if epoch==0 or epoch%10==0 or epoch==49:
            print('Epoch [%d/%d], Loss: %.6f, Supervised Loss: %.6f, Unsupervised Loss: %.6f, Time: %.2f' % 
                  (epoch + 1, num_epochs, float(loss_mean), float(supl_mean), float(unsupl_mean), time.time()-start_t))

        # update the ensemble outputs with an exponential moving average of the model outputs, then correct the startup bias
        Z = alpha * Z + (1. - alpha) * outputs
        z = Z * (1. / (1. - alpha ** (epoch + 1)))

        if loss_mean < best_loss:
            best_loss = loss_mean
            torch.save({'state_dict': model.state_dict()}, 'model_best.pth')

        model.eval()
        acc = GaussianNoise.calc_metrics(model, val_loader)
        if epoch==0 or epoch%10==0 or epoch==49:
            print('Acc : %.2f' % acc)

def evaluation(model, loader):

    # test best model
    checkpoint = torch.load('model_best.pth')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    correct = 0
    total = 0
    for i, (samples, labels) in enumerate(loader):
        samples = samples.cuda()
        labels = labels.requires_grad_(False).cuda()
        outputs = model(samples)
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels.detach().view_as(predicted)).sum()
    acc = 100 * float(correct) / total
    print('Acc (best model): %.2f' % acc)

In [42]:
# global vars
n_exp = 1 # number of experiments, try 5 different seed
k = 100 # keep k labeled data in whole training set, other without label

# dataset vars
m = 0.1307
s = 0.3081

# model vars
drop = 0.5 # dropout probability
std = 0.15 # std of gaussian noise
fm1 = 32 # channels of the first conv
fm2 = 64 # channels of the second conv
w_norm = True

# optim vars
learning_rate = 0.002
beta2 = 0.99 # second momentum for Adam
num_epochs = 50
batch_size = 64

# temporal ensembling vars
alpha = 0.6 # ensembling momentum
data_norm = 'channelwise' # image normalization
divide_by_bs = False # whether we divide supervised cost by batch_size

# RNG
rng = np.random.RandomState(42)
seeds = [rng.randint(200) for _ in range(n_exp)]

In [43]:
# cfg = vars(config)

# prepare data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(m,s)])
train_dataset, val_dataset = mnist_dataset(root='~/datasets/MNIST', transform=transform)
ntrain = len(train_dataset)


for i in range(n_exp):
    model = CNN(batch_size, std, fm1=fm1, fm2=fm2).cuda()
    seed = seeds[i]
    train_loader, val_loader, indices = sample_train(train_dataset, val_dataset, batch_size=batch_size,
                                                 k=k, n_classes=10, seed=seed, shuffle_train=False)
    train(model, train_loader, val_loader,seed, k, alpha, learning_rate,
         num_epochs, batch_size, ntrain)
    evaluation(model, val_loader)


Epoch: 1
---------------------
Epoch [1/50], Loss: 0.761350, Supervised Loss: 0.790792, Unsupervised Loss: 0.065815, Time: 4.39
Acc : 31.96

Epoch: 11
---------------------
Epoch [11/50], Loss: 0.144925, Supervised Loss: 0.149759, Unsupervised Loss: 0.037651, Time: 4.31
Acc : 86.34

Epoch: 21
---------------------
Epoch [21/50], Loss: 0.080538, Supervised Loss: 0.083150, Unsupervised Loss: 0.027056, Time: 4.32
Acc : 92.20

Epoch: 31
---------------------
Epoch [31/50], Loss: 0.054949, Supervised Loss: 0.056718, Unsupervised Loss: 0.021271, Time: 4.26
Acc : 94.06

Epoch: 41
---------------------
Epoch [41/50], Loss: 0.042135, Supervised Loss: 0.043472, Unsupervised Loss: 0.017923, Time: 4.33
Acc : 94.83

Epoch: 50
---------------------
Epoch [50/50], Loss: 0.035941, Supervised Loss: 0.037041, Unsupervised Loss: 0.016183, Time: 4.50
Acc : 95.20
Acc (best model): 95.20


In [None]:
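* Interpretation of results: as mentioned above, Temporal Ensembling does not rely on a single network like the Pi-Model;
    it ensembles the predictions of multiple network evaluations and uses the result as the target.
    
    This yields targets that are more robust than those of the Pi-Model.
    
    Training Temporal Ensembling on the 60,000 MNIST training images with a batch size of 64 for 50 epochs,
    the logged values at the last epoch are "supervised loss: 0.037041" and "unsupervised loss: 0.016183" (total loss: 0.035941).
    
    The test accuracy is 0.9520.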

## 2. Mean Teacher

In [None]:
* Source: https://github.com/shenkev/Pytorch-Mean-Teacher

In [47]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, model, mean_teacher, device, train_loader, test_loader, optimizer, epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)

        ########################### CODE CHANGE HERE ######################################
        # forward pass with mean teacher
        # torch.no_grad() prevents gradients from being passed into mean teacher model
        with torch.no_grad():
            mean_t_output = mean_teacher(data)

        ########################### CODE CHANGE HERE ######################################
        # consistency loss (example with MSE, you can change)
        const_loss = F.mse_loss(output, mean_t_output)

        ########################### CODE CHANGE HERE ######################################
        # set the consistency weight (should schedule)
        weight = 0.2
        loss = F.nll_loss(output, target) + weight*const_loss
        unsupervised_loss = weight*const_loss
        supervised_loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        ########################### CODE CHANGE HERE ######################################
        # update mean teacher, (should choose alpha somehow)
        # Use the true average until the exponential average is more correct
        alpha = 0.95
        for mean_param, param in zip(mean_teacher.parameters(), model.parameters()):
            mean_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tUn_loss: {:.6f}\tS_loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item(), unsupervised_loss.item(), supervised_loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    model.train()


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=60000, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
#     args = parser.parse_args()
    args, _ = parser.parse_known_args()
    
    use_cuda = torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    ########################### CODE CHANGE HERE ######################################
    # initialize mean teacher
    mean_teacher = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, mean_teacher, device, train_loader, test_loader, optimizer, epoch)

#     test(args, model, device, test_loader)
    test(args, mean_teacher, device, test_loader)

    if (args.save_model):
        torch.save(model.state_dict(), "mnist_mean_teacher.pt")


if __name__ == '__main__':
    main()


Test set: Average loss: 0.0189, Accuracy: 9938/10000 (99%)



In [None]:
* Interpretation of results:
    Training the Mean Teacher model on the 60,000 MNIST training images with a batch size of 64 for 50 epochs,
    the losses settled at "supervised loss: 0.0016" and "unsupervised loss: 0.002".
    
    The test accuracy is 0.9938.
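
In [None]:
* Sketch of the teacher update with alpha = 0.95 (plain Python lists instead of tensors):
    The training loop updates the teacher after every batch as teacher = alpha * teacher + (1 - alpha) * student.
    The toy below holds the student fixed, purely for illustration, to show how quickly the teacher's own
    initial weights fade: their share after n updates is alpha^n. The weight values are made up.

```python
alpha = 0.95  # value hard-coded in the training loop

def ema_update(teacher, student, alpha):
    # teacher <- alpha * teacher + (1 - alpha) * student, element by element
    return [alpha * t + (1.0 - alpha) * s for t, s in zip(teacher, student)]

teacher = [1.0, -1.0]   # stands in for the teacher's initial weights
student = [0.2, 0.4]    # held fixed here only to keep the example small

for step in range(1, 101):
    teacher = ema_update(teacher, student, alpha)
    if step in (1, 20, 100):
        print(step, [round(t, 4) for t in teacher],
              "share of initial weights:", round(alpha ** step, 4))
```

    With 60000 / 64 (about 938) updates per epoch, the teacher's initial weights contribute essentially nothing after the first epoch.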

In [None]:
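* Overall interpretation:
    On MNIST, with the same batch size (64) and number of epochs (50), the test accuracies rank as follows:
    Mean Teacher, then Temporal Ensembling.
    
    1. Mean Teacher: 99.38%
    2. Temporal Ensembling: 95.20%
        
    Temporal Ensembling stabilizes its noisy outputs by accumulating them with an EMA (exponential moving average),
    whereas Mean Teacher keeps a separate teacher and student and 'uses the EMA of the student weights as the teacher'.
    
    Note that the two runs differ in supervision: the Temporal Ensembling run keeps only k = 100 labels, while the
    Mean Teacher script above computes the supervised loss on every training label, so part of the gap comes from that.
    
    With that caveat, the results suggest that, compared with averaging outputs as Temporal Ensembling does,
    keeping a separate teacher and student as in Mean Teacher is lighter on memory and better in speed and accuracy.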
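
In [None]:
* Rough check of the memory remark (a sketch, counting float values only):
    Temporal Ensembling keeps buffers of shape (ntrain, n_classes), so its extra memory grows with the dataset,
    while Mean Teacher keeps one extra copy of the network weights, so its extra memory grows with the model.
    The helpers `conv_params` and `linear_params` are hypothetical and count weights plus biases for the layers of `Net`.
    For a much larger network the comparison could go the other way.

```python
def conv_params(c_in, c_out, k):
    return c_out * c_in * k * k + c_out      # weights + biases

def linear_params(n_in, n_out):
    return n_in * n_out + n_out              # weights + biases

# layers of Net from the Mean Teacher cell
net_params = (conv_params(1, 20, 5) + conv_params(20, 50, 5)
              + linear_params(4 * 4 * 50, 500) + linear_params(500, 10))

# Temporal Ensembling keeps Z, z and outputs, each of shape (ntrain, n_classes)
ntrain, n_classes = 60000, 10
per_buffer = ntrain * n_classes

print("teacher copy of Net:", net_params, "floats")
print("one TE buffer:      ", per_buffer, "floats")
print("Z + z + outputs:    ", 3 * per_buffer, "floats")
```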