In [None]:
import gc
import os
import sys
sys.path.append('../')
import time
from collections import namedtuple
from tqdm import tqdm
import numpy as np
import torch
import torchvision
import src.utils as utils
from src.models import LeNetMNIST, ResVestimatorMNIST, R18VestimatorMNIST, Vestimator, LeVestimatorMNIST
from src.frameworks.online_dvrl import Odvrl
from src.frameworks.online_proposed import Proposed

In [None]:
from collections import namedtuple
# Immutable record of the run settings handed to the framework constructor.
# The estimator-architecture fields (input_dim, hidden_dim, output_dim,
# layer_number) are currently disabled.
Parameters = namedtuple(
    'Parameters',
    (
        # where results are written
        'saving_path',
        # optimisation / data loading
        'val_batch_size',
        'epochs',
        'device',
        'learning_rate',
        'num_workers',
        # exploration schedule
        'explore_strategy',
        'epsilon0',
        'threshold',
    ),
)

In [None]:
# Concrete settings for this run. The estimator-architecture options
# (input_dim=794, hidden_dim=100, output_dim=10, layer_number=5) are
# currently disabled, matching the fields declared on `Parameters`.
parameters = Parameters(
    # output location
    saving_path='../logs',
    # hardware / data loading
    device='cuda',
    num_workers=1,
    val_batch_size=128,
    # optimisation
    epochs=15,
    learning_rate=1e-6,
    # exploration schedule; explore_strategy is one of
    # ['linear', 'constant', 'exponential']
    explore_strategy='constant',
    epsilon0=0.5,
    threshold=0.5,
)

In [None]:
# Experiment-level settings.
T = 10             # number of consecutive chunks the training stream is split into
noise_level = 0.2  # forwarded to utils.create_noisy_mnist
seed = 3407        # seed for the NumPy / PyTorch global RNGs
num_weak = 100     # forwarded to the framework constructor

# Seed *before* building the networks: presumably their initial weights are
# drawn from torch's global RNG, so seeding only afterwards would leave the
# starting point of `pred_model` different on every kernel restart.
np.random.seed(seed)
torch.manual_seed(seed)

pred_model = LeNetMNIST()
val_model = LeNetMNIST()
value_estimator = LeVestimatorMNIST()

In [None]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: network exposing a ``classify(batch)`` method; its output is
            fed straight into ``nll_loss`` (so it is expected to be
            log-probabilities -- TODO confirm against the model code).
        device: device the mini-batches are moved to before the forward pass.
        train_loader: iterable yielding ``(data, target)`` pairs.
        optimizer: optimizer that updates ``model``'s parameters.
        epoch: accepted for the caller's convenience; not used here.

    Returns:
        None. ``model`` is left in training mode and is updated in place.
    """
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        predictions = model.classify(inputs)
        batch_loss = torch.nn.functional.nll_loss(predictions, labels)
        batch_loss.backward()
        optimizer.step()

In [None]:
# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalise the single channel with the constants 0.1307 / 0.3081.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [None]:
# Warm-start the value estimator: reuse the cached checkpoint when it exists,
# otherwise pretrain it as a plain MNIST classifier and cache the weights.
levestimator_ckpt = '../data/pretrained_models/pretrained_levestimator.pt'
if os.path.isfile(levestimator_ckpt):
    # map_location='cpu' keeps the load independent of the device the
    # checkpoint was saved from; load_state_dict copies the values into the
    # model's own parameters, wherever those live.
    value_estimator.load_state_dict(
        torch.load(levestimator_ckpt, map_location='cpu')
    )
else:
    # Create the target directory up front so that a missing folder cannot
    # throw away the result of the training epochs at save time.
    os.makedirs(os.path.dirname(levestimator_ckpt), exist_ok=True)
    pretrain_device = parameters.device  # same device the rest of the run uses
    train_dataset = torchvision.datasets.MNIST('../data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
    optimizer = torch.optim.Adadelta(value_estimator.parameters(), lr=1.0)
    value_estimator.to(pretrain_device)
    for epoch in tqdm(range(14)):
        train(value_estimator, pretrain_device, train_loader, optimizer, epoch)
    torch.save(value_estimator.state_dict(), levestimator_ckpt)

In [None]:
# Fix the NumPy and PyTorch global RNG seeds for the steps that follow.
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Run 'pretrain_mnist.ipynb' first, then load the pretrained model it saved.
# map_location='cpu' makes the load independent of the device the checkpoint
# was written from; load_state_dict copies the values into val_model's own
# parameters.
state_dict = torch.load(
    "../data/pretrained_models/pretrained_lenetmnist.pt",
    map_location='cpu',
)
val_model.load_state_dict(state_dict)

In [None]:
# Build the noisy MNIST data; `noisy_idxs` is used later as the ground truth
# when the discover rate is computed.
x_train, y_train, x_test, y_test, noisy_idxs = utils.create_noisy_mnist(
    method='uniform',
    noise_level=noise_level,
)

In [None]:
# Sort the noisy-sample indices in place, then display the first 20 of them.
noisy_idxs.sort()
noisy_idxs[:20]

In [None]:
# Convert the training features / labels to tensors (same names are rebound).
x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)
# Official MNIST test split, preprocessed with `transform`; passed to the
# engine as its validation dataset.
test_data = torchvision.datasets.MNIST(
    '../data',
    train=False,
    download=True,
    transform=transform,
)

In [None]:
# Total number of training samples (length of the label tensor).
num_data = y_train.shape[0]

In [None]:
# Assemble the framework object from the three networks and the run settings.
engine = Proposed(
    num_weak=num_weak,
    pred_model=pred_model,
    val_model=val_model,
    value_estimator=value_estimator,
    parameters=parameters,
)

In [None]:
# Display the current wall-clock time just before the streaming loop runs.
time.asctime()

In [None]:
# Feed the training set to the engine in T consecutive, equally sized chunks.
# After each step, score the chunk and report how many of its noisy samples
# are among the lowest-valued ones ("discover rate").
# NOTE: when num_data is not a multiple of T, the trailing num_data % T
# samples are never visited.
subset_len = num_data // T
eval_batch_size = 128  # number of samples scored per engine.evaluate call
for t in range(T):
    start_id = t * subset_len
    end_id = min((t + 1) * subset_len, num_data)
    engine.one_step(
        t,
        X=x_train[start_id:end_id],
        y=y_train[start_id:end_id],
        val_dataset=test_data,
    )
    utils.super_save()
    # Noisy-sample indices (global) that fall inside the current chunk.
    current_noisy_idxs = np.extract((noisy_idxs >= start_id) & (noisy_idxs < end_id), noisy_idxs)
    current_corrupted_num = len(current_noisy_idxs)
    if current_corrupted_num > 0:
        values = []
        # Walk the *whole* chunk, including a final partial batch. Iterating
        # subset_len // 128 full batches would leave the tail unscored, so
        # noisy samples located there could never be discovered.
        for batch_start in range(start_id, end_id, eval_batch_size):
            batch_end = min(batch_start + eval_batch_size, end_id)
            part_values = engine.evaluate(x_train[batch_start:batch_end], y_train[batch_start:batch_end])
            values = np.concatenate((values, part_values))
            utils.super_save()
        # Positions within the chunk ordered from lowest to highest value.
        guess_idxs = np.argsort(values)
        print(guess_idxs[:10])
        print(current_noisy_idxs[:10])
        # guess_idxs are chunk-local, so shift by start_id before comparing
        # with the global noisy indices.
        discover_rate = len(np.intersect1d(start_id+guess_idxs[:current_corrupted_num], current_noisy_idxs)) / current_corrupted_num
        print('discover rate: {}'.format(discover_rate))

In [None]:
# Display the current wall-clock time once the streaming loop has finished.
time.asctime()