In [None]:
import os
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data_utils
from torchvision import datasets
from torchvision import transforms as T

%matplotlib inline
import network
from utils import make_z, make_y, make_fixed_z, make_fixed_y
from utils.drs import drs, fitting_capacity

In [None]:
# nz: noise size passed to make_z / the generator; ny: number of classes;
# nc: image channels (MNIST is grayscale, so 1).
nz, ny, nc = 100, 10, 1
# Training split (downloaded alongside the test split).
dataset = datasets.MNIST('./dataset/mnist', train=True, download=True)
# Real held-out digits, used as the evaluation set passed to `fitting_capacity`.
testset = datasets.MNIST('./dataset/mnist', train=False, download=True)
# Resize to 32x32 and map pixels from [0, 1] to [-1, 1].
# NOTE(review): presumably this matches the generator's output range — confirm.
# Normalize needs exactly one mean/std entry per channel. A 3-tuple cannot be
# broadcast in place onto a 1x32x32 MNIST tensor and raises a RuntimeError,
# so the tuples are sized from `nc`.
testset.transform = T.Compose([
    T.Resize(32),
    T.ToTensor(),
    T.Normalize((0.5,) * nc, (0.5,) * nc),
])

# Conditional generator / discriminator pair on the GPU. The weights are
# overwritten by `load_network` (next in this cell).
# NOTE(review): `use_sn=True` must match the architecture the checkpoint was
# trained with, or load_state_dict will reject the keys — confirm.
netG = network.ACGAN_MNIST_Generator(nz, nc, ny).cuda()
netD = network.ACGAN_MNIST_Discriminator(nc, ny, use_sn=True).cuda()

# Only inference is done in the cells shown, so switch both to eval mode.
for _net in (netG, netD):
    _net.eval()

def load_network(name, it=1):
    """Load saved weights into the global ``netG`` and ``netD`` modules.

    Args:
        name: directory containing the checkpoints (relative to the working
            directory, or absolute).
        it: checkpoint index (presumably the training iteration/epoch); the
            files read are ``<name>/netG_<it>.pth`` and ``<name>/netD_<it>.pth``.

    Raises:
        FileNotFoundError: if either checkpoint file is missing.
        RuntimeError: if a state dict does not match the module architecture.
    """
    # os.path.join keeps absolute `name`s absolute. The previous './{}/' prefix
    # silently turned '/abs/dir' into the relative path './/abs/dir'.
    netG_path = os.path.join(name, 'netG_{}.pth'.format(it))
    netD_path = os.path.join(name, 'netD_{}.pth'.format(it))
    # Deserialize onto the CPU. load_state_dict copies the values into the
    # parameters' existing device anyway, and this avoids failures when the
    # checkpoint was saved from a GPU index that does not exist on this machine.
    # NOTE(review): torch.load unpickles its input; only load trusted files.
    netG.load_state_dict(torch.load(netG_path, map_location='cpu'))
    netD.load_state_dict(torch.load(netD_path, map_location='cpu'))

# Restore the `logs/reweight_base` checkpoint (index 1, the default) into netG / netD.
load_network(name='logs/reweight_base', it=1)

In [None]:
# Baseline: draw `num_samples` generator samples for each of the `ny` classes
# (no rejection step), then score them with `fitting_capacity` against the
# real test set.
num_samples = 5000  # samples per class
z = make_z(num_samples * ny, nz).cuda()
y = make_fixed_y(num_samples, ny).cuda()

with torch.no_grad():  # inference only; no autograd graph needed
    x = netG(z, y)

samples = data_utils.TensorDataset(*(t.cpu() for t in (x, y)))
acc = fitting_capacity(samples, testset)
print(acc)

In [None]:
# Same evaluation as the baseline cell, but with samples produced by `drs`
# (presumably discriminator rejection sampling; see utils.drs for what `perc`
# controls — TODO confirm).
num_samples = 5000  # samples per class
x = drs(netG, netD, num_samples, perc=10)
# Build labels from this cell's own `num_samples` instead of reusing `y` from
# the baseline cell. Otherwise, editing `num_samples` here (or running this
# cell alone) yields mismatched sample/label counts.
# NOTE(review): assumes `drs` returns samples in the same class layout as
# `make_fixed_y(num_samples, ny)` — the original code relied on the same layout.
y_drs = make_fixed_y(num_samples, ny)
samples = data_utils.TensorDataset(x.cpu(), y_drs.cpu())
acc = fitting_capacity(samples, testset)
print(acc)