<a href="https://colab.research.google.com/github/respect5716/deep-learning-paper-implementation/blob/main/01_General/NDF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NDF (Neural Data Filter)

## 0. Paper

### Info
* Title: Learning What Data to Learn
* Author: Yang Fan et al.
* Task: Data Filtration
* URL: https://arxiv.org/abs/1702.08635


### Features
* Dataset: CIFAR-10

### Reference
* https://github.com/kuangliu/pytorch-cifar
* https://github.com/Finspire13/pytorch-policy-gradient-example

## 1. Setting

In [None]:
!pip install -q wandb
!pip install -q transformers
!pip install -q pytorch_lightning

In [1]:
import os
import wandb
from glob import glob
from tqdm.auto import tqdm
from typing import Optional

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import pytorch_lightning as pl

# Seed the global RNGs once, before any data split or model initialisation.
# The top-level `pl.seed_everything` is used because the `pl.utilities.seed`
# module path is not available in recent PyTorch Lightning releases.
_ = pl.seed_everything(999)

Global seed set to 999


In [2]:
import os
if os.path.isdir('ds_utils'):
    !rm -rf ds_utils
!git clone -q https://github.com/respect5716/ds_utils.git

from ds_utils.pytorch_lightning import BaseModule, BaseDataModule, CheckpointCallback, AttributeDict
from ds_utils.metric import accuracy

In [3]:
# Run configuration, wrapped in an AttributeDict and read below as `args.<key>`.
args = AttributeDict(dict(
    # Mini-batch size handed to load_dataloader.
    batch_size=128,
    # Number of outer episodes in the filter-training loop.
    episode_size=100,
    # Project directory on a mounted drive; not referenced by the visible cells.
    base_dir='/content/drive/Shared drives/Yoon/Project/Doing/Deep Learning Paper Implementation',
    # Not referenced by the visible cells.
    gamma=0.99,
))

## 2. Data

In [4]:
def load_dataloader(batch_size, transform):
    """Build the CIFAR-10 train / validation / test data loaders.

    The 50,000 official training images are split at random into 45,000
    training and 5,000 validation samples. Training samples are read through
    ``transform['train']``; validation and test samples are read through
    ``transform['eval']`` so that evaluation is not affected by training-time
    augmentation.

    Args:
        batch_size: Batch size used by all three loaders.
        transform: Mapping with the keys ``'train'`` and ``'eval'`` holding
            the transforms applied to training and evaluation images.

    Returns:
        A tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.
    """
    root = './data'

    # Two views of the same training images that differ only in the transform.
    # The first call takes care of the download, so the second one skips it.
    train_view = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform['train'])
    eval_view = torchvision.datasets.CIFAR10(root=root, train=True, download=False, transform=transform['eval'])

    # Split once, then point the validation indices at the un-augmented view.
    train_dataset, val_split = torch.utils.data.random_split(train_view, [45000, 5000])
    val_dataset = torch.utils.data.Subset(eval_view, val_split.indices)
    test_dataset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform['eval'])

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [5]:
# Image pipelines keyed by split. Both end with the same per-channel
# Normalize(mean, std) step (presumably CIFAR-10 channel statistics); only the
# training pipeline adds a random 32x32 crop taken after 4 pixels of padding.
transform = dict(
    train=transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2439, 0.2616)),
    ]),
    eval=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2439, 0.2616)),
    ]),
)

In [6]:
# Build the three loaders from the configured batch size and the transforms above.
train_loader, val_loader, test_loader = load_dataloader(args.batch_size, transform)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Pull a single batch from the training loader as a sanity check.
inputs = next(iter(train_loader))
# Display the sizes of the image tensor and of the label tensor.
inputs[0].size(), inputs[1].size()

(torch.Size([128, 3, 32, 32]), torch.Size([128]))

## 3. Model

In [8]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The skip path is an empty ``nn.Sequential`` (identity) unless the stride
    or the channel count changes, in which case a 1x1 convolution plus batch
    norm projects the input to the shape of the main branch.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by the block (times ``expansion``).
        stride: Stride of the first convolution and of the projection.
    """

    # Ratio between the block's output channels and `planes`.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: only the first convolution may change the resolution.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip branch: project only when the shapes would not line up.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)


class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last 1x1 convolution widens the features to ``expansion * planes``
    channels. The skip path is an empty ``nn.Sequential`` (identity) when the
    input already has that shape; otherwise a 1x1 convolution plus batch norm
    projects it.

    Args:
        in_planes: Number of input channels.
        planes: Width of the two inner convolutions.
        stride: Stride of the 3x3 convolution and of the projection.
    """

    # Ratio between the block's output channels and `planes`.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: the 3x3 convolution is the only one that may downsample.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Skip branch: identity when shapes already match, projection otherwise.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)


class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four residual stages and a linear head.

    The stages have widths 64, 128, 256 and 512 (times ``block.expansion``);
    stages two to four start with a stride-2 block. After the last stage the
    features are average-pooled with a 4x4 window, flattened and passed to the
    linear layer, which expects ``512 * block.expansion`` inputs.

    Args:
        block: Block class called as ``block(in_planes, planes, stride)`` and
            exposing an ``expansion`` attribute.
        num_blocks: Sequence with the number of blocks for each of the four stages.
        num_classes: Size of the output layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Channel count entering the next block; advanced by _make_layer.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map a batch of images to ``num_classes`` outputs per sample."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = F.avg_pool2d(feats, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


def ResNet18():
    """Build a ResNet from BasicBlock with two blocks per stage and the default output size."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2])
    
def NDF():
    """Build the data-filter network: the same ResNet layout with a single output per sample."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=1)

## 4. NDF

In [9]:
def check_device(model):
    """Return the device of the first parameter of ``model``."""
    first_param = next(iter(model.parameters()))
    return first_param.device


def valid_accuracy(learner, val_loader):
    """Compute the classification accuracy of ``learner`` over ``val_loader``.

    The learner is switched to evaluation mode and is left in that mode when
    the function returns. Gradients are disabled during the pass.

    Args:
        learner: Model mapping a batch of inputs to per-class scores.
        val_loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        A scalar tensor: the fraction of samples whose highest-scoring class
        equals the target.
    """
    learner.eval()
    device = check_device(learner)

    per_batch_hits = []
    with torch.no_grad():
        for batch_inputs, batch_targets in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            _, predicted = learner(batch_inputs).max(dim=1)
            per_batch_hits.append((predicted == batch_targets).float())

    # Concatenate first so every sample has equal weight, whatever the batch sizes.
    return torch.cat(per_batch_hits).mean()

In [10]:
# The data filter network, moved to the GPU. This single instance is shared by
# every episode of the training loop below.
ndf = NDF().cuda()
# Adam optimiser over the filter's parameters.
ndf_optim = torch.optim.Adam(ndf.parameters(), lr=0.001)

In [11]:
# Validation accuracy above which an episode stops early.
criteria = 0.80
# Maximum number of learner steps per episode; also the cosine schedule length.
iter_size = 10000
# Moving average of the episode rewards, updated after every episode.
b = 0

In [None]:
# Policy-gradient training of the data filter.
#
# Each episode trains a freshly initialised learner on the samples the filter
# keeps, and stops once validation accuracy exceeds `criteria` (or after
# `iter_size` steps). The episode reward is -log(steps_used / iter_size), so
# reaching the target in fewer steps yields a larger reward. The filter is then
# updated once with REINFORCE, using the moving average `b` of past rewards as
# a baseline.
for ep in tqdm(range(args.episode_size)):
    learner = ResNet18().cuda()
    device = check_device(learner)
    learner_optim = torch.optim.SGD(learner.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(learner_optim, T_max=iter_size)

    rt = 0.
    # NOTE(review): state_hist keeps every input batch on the GPU until the
    # episode ends, so memory grows linearly with the episode length.
    reward_hist, state_hist, action_hist = [], [], []

    # One persistent iterator per episode, restarted when an epoch is exhausted,
    # instead of rebuilding the loader iterator on every step.
    train_iter = iter(train_loader)
    pbar = tqdm(range(iter_size))
    for t in pbar:
        learner.train()

        try:
            inputs, targets = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            inputs, targets = next(train_iter)
        inputs, targets = inputs.to(device), targets.to(device)

        # The filter outputs one keep-probability per sample. No graph is needed
        # here because the log-probabilities are recomputed after the episode.
        # squeeze(1) drops only the output dimension, so a batch of size one
        # stays one-dimensional.
        states = inputs
        with torch.no_grad():
            action_prob = torch.sigmoid(ndf(states)).squeeze(1)
        action = torch.distributions.Bernoulli(action_prob).sample()
        state_hist.append(states)
        action_hist.append(action)

        # Train the learner on the kept samples only. If the filter rejected the
        # whole batch there is nothing to learn from, so the update is skipped;
        # the step still counts against the episode length below.
        keep = action.bool()
        if keep.any():
            inputs, targets = inputs[keep], targets[keep]
            logits = learner(inputs)
            loss = F.cross_entropy(logits, targets)

            learner_optim.zero_grad()
            loss.backward()
            learner_optim.step()
            scheduler.step()

        rt = valid_accuracy(learner, val_loader)
        reward_hist.append(rt)
        pbar.set_postfix({'val_acc': float(rt)})
        if rt > criteria:
            break

    # Terminal reward of the episode, and its advantage over the baseline built
    # from previous episodes. The baseline is updated only after it was used.
    rl = -np.log(len(reward_hist) / iter_size)
    advantage = float(rl - b)
    b = 0.8 * b + 0.2 * rl

    # Clear gradients left over from the previous episode, then accumulate the
    # REINFORCE gradient over every recorded step and apply a single update.
    ndf_optim.zero_grad()
    for state, action in zip(state_hist, action_hist):
        prob = torch.sigmoid(ndf(state))
        loss = -torch.distributions.Bernoulli(prob).log_prob(action.unsqueeze(1)) * advantage
        loss = loss.mean()
        loss.backward()

    ndf_optim.step()