In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from collections import defaultdict
from tqdm import tqdm
from torch.utils.data import Subset
import numpy as np
import torch.nn.functional as F
from datasets import GoldCorrectionDataset
from glc import CorrectionGenerator, GoldCorrectionLossFunction

In [6]:
from torchvision.datasets import MNIST
trn_ds = MNIST('data/', train=True, transform=ToTensor())
val_ds = MNIST('data/', train=False, transform=ToTensor())

In [7]:
c_gen = CorrectionGenerator(dataset=trn_ds, randomization_strength=1.0)

100%|██████████| 6000/6000 [00:00<00:00, 6324.63it/s]


In [8]:
trusted_dataset, untrusted_dataset = c_gen.fetch_datasets()

In [12]:
from torched.customs.layers import LinearLayer, Flatten
from torched.trainer_utils import Train
class Net(nn.Module):
    def __init__(self, in_dims, hid_dims, out_dims):
        super(Net, self).__init__()
        self.net = nn.Sequential(Flatten(),
                                LinearLayer(in_dims, hid_dims, use_bn=True),
                                LinearLayer(hid_dims, hid_dims, use_bn=True),
                                nn.Linear(hid_dims, out_dims))
    def forward(self, x):
        if isinstance(x, list):
            inp, c = x[0], x[1]
        else:
            inp = x
        out = self.net(inp)
        if isinstance(x, list):
            return [out, c]
        return out

In [13]:
model = Net(784, 300, 10)
u_dl = DataLoader(untrusted_dataset, batch_size=32)
v_dl = DataLoader(val_ds, batch_size=32)
trainer = Train(model, [u_dl, v_dl], cuda=False)

In [14]:
trainer.train(1e-4, 3, 2, crit=nn.CrossEntropyLoss(), opt='adamW')

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

Train Loss 2.328309: 100%|██████████| 1688/1688 [00:40<00:00, 35.15it/s]
Valid Loss 2.332816: 100%|██████████| 313/313 [00:02<00:00, 109.70it/s]
Train Loss 2.299558: 100%|██████████| 1688/1688 [00:42<00:00, 39.31it/s]
Valid Loss 2.330564: 100%|██████████| 313/313 [00:02<00:00, 117.04it/s]
Train Loss 2.262732: 100%|██████████| 1688/1688 [00:42<00:00, 39.47it/s]
Valid Loss 2.329315: 100%|██████████| 313/313 [00:02<00:00, 113.94it/s]
Train Loss 2.271581: 100%|██████████| 1688/1688 [00:43<00:00, 38.89it/s]
Valid Loss 2.338990: 100%|██████████| 313/313 [00:02<00:00, 119.94it/s]
Train Loss 2.245886: 100%|██████████| 1688/1688 [00:43<00:00, 35.17it/s]
Valid Loss 2.339327: 100%|██████████| 313/313 [00:02<00:00, 107.19it/s]
Train Loss 2.205070: 100%|██████████| 1688/1688 [00:43<00:00, 38.45it/s]
Valid Loss 2.340389: 100%|██████████| 313/313 [00:02<00:00, 109.98it/s]
Train Loss 2.172739: 100%|██████████| 1688/1688 [00:45<00:00, 37.32it/s]
Valid Loss 2.348561: 100%|██████████| 313/313 [00:02<00:0







In [15]:
label_correction_matrix = c_gen.generate_correction_matrix(trainer.model, 32)

Processing label 0: 100%|██████████| 19/19 [00:00<00:00, 114.33it/s]
Processing label 1: 100%|██████████| 20/20 [00:00<00:00, 122.72it/s]
Processing label 2: 100%|██████████| 20/20 [00:00<00:00, 189.87it/s]
Processing label 3: 100%|██████████| 18/18 [00:00<00:00, 177.34it/s]
Processing label 4: 100%|██████████| 19/19 [00:00<00:00, 175.65it/s]
Processing label 5: 100%|██████████| 18/18 [00:00<00:00, 182.80it/s]
Processing label 6: 100%|██████████| 19/19 [00:00<00:00, 161.49it/s]
Processing label 7: 100%|██████████| 20/20 [00:00<00:00, 155.87it/s]
Processing label 8: 100%|██████████| 18/18 [00:00<00:00, 168.99it/s]
Processing label 9: 100%|██████████| 22/22 [00:00<00:00, 125.26it/s]

Done





In [16]:
gold_ds = GoldCorrectionDataset(trusted_dataset, untrusted_dataset)
gold_dl = DataLoader(gold_ds, batch_size=32, shuffle=True)
g_val_ds = GoldCorrectionDataset(val_ds, val_ds, valid=True)
g_val_dl = DataLoader(g_val_ds, batch_size=32)
gold_loss = GoldCorrectionLossFunction(label_correction_matrix)

In [17]:
trainer.dataloader = gold_dl
trainer.val_dataloader = g_val_dl
trainer.train(1e-4, 3, 2, crit=gold_loss, opt='adamW')

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

Train Loss 2.148798: 100%|██████████| 1875/1875 [00:48<00:00, 38.78it/s]
Valid Loss 0.439217: 100%|██████████| 625/625 [00:06<00:00, 95.55it/s] 
Train Loss 2.106702: 100%|██████████| 1875/1875 [00:51<00:00, 36.76it/s]
Valid Loss 0.236635: 100%|██████████| 625/625 [00:06<00:00, 98.85it/s] 
Train Loss 2.091937: 100%|██████████| 1875/1875 [00:52<00:00, 35.83it/s]
Valid Loss 0.220722: 100%|██████████| 625/625 [00:06<00:00, 99.50it/s] 
Train Loss 2.090299: 100%|██████████| 1875/1875 [00:51<00:00, 34.03it/s]
Valid Loss 0.183325: 100%|██████████| 625/625 [00:06<00:00, 103.53it/s]
Train Loss 2.082261: 100%|██████████| 1875/1875 [00:52<00:00, 35.38it/s]
Valid Loss 0.163479: 100%|██████████| 625/625 [00:06<00:00, 92.16it/s] 
Train Loss 2.077849: 100%|██████████| 1875/1875 [00:53<00:00, 34.91it/s]
Valid Loss 0.157068: 100%|██████████| 625/625 [00:05<00:00, 104.54it/s]
Train Loss 2.076205: 100%|██████████| 1875/1875 [00:54<00:00, 34.29it/s]
Valid Loss 0.155766: 100%|██████████| 625/625 [00:06<00:0




In [18]:
def epoch(loader, model, opt=None):
    """Standard training/evaluation epoch over the dataset"""
    total_loss, total_err = 0.,0.
    for X,y in loader:
        X,y = X, y
        yp = model(X)
        loss = nn.CrossEntropyLoss()(yp,y)
        if opt:
            opt.zero_grad()
            loss.backward()
            opt.step()
        
        total_err += (yp.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * X.shape[0]
    return total_err / len(loader.dataset), total_loss / len(loader.dataset)

In [19]:
err, loss = epoch(v_dl, trainer.model)

In [20]:
1 - err

0.9545