In [2]:
import numpy as np
import struct
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torchvision
from copy import deepcopy
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import time
from scipy import spatial
from util.util import mnist_noise

from trajectoryReweight.model import WeightedCrossEntropyLoss, TrajectoryReweightNN
from trajectoryReweight.baseline import StandardTrainingNN
from trajectoryReweight.gmm import GaussianMixture

# Use the first CUDA device when PyTorch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

def read_idx(filename):
    """Read an IDX-format file (the container used for the MNIST data) into an array.

    Layout, all integers big-endian: two zero bytes, one data-type code, one
    byte holding the number of dimensions, then one uint32 per dimension,
    then the raw element data.

    Parameters
    ----------
    filename : str or path-like
        Path of the (uncompressed) IDX file.

    Returns
    -------
    numpy.ndarray
        Read-only array with the shape given in the header. Unsigned-byte
        files come back as ``uint8``; multi-byte element types keep the
        file's big-endian byte order.

    Raises
    ------
    ValueError
        If the header is truncated, does not start with two zero bytes,
        names an unknown data-type code, or the payload size does not match
        the declared shape.
    """
    # Data-type codes defined by the IDX format; multi-byte types are stored MSB first.
    idx_dtypes = {
        0x08: np.dtype(np.uint8),
        0x09: np.dtype(np.int8),
        0x0B: np.dtype('>i2'),
        0x0C: np.dtype('>i4'),
        0x0D: np.dtype('>f4'),
        0x0E: np.dtype('>f8'),
    }
    with open(filename, 'rb') as f:
        header = f.read(4)
        if len(header) != 4:
            raise ValueError('{}: truncated IDX header'.format(filename))
        zero, data_type, dims = struct.unpack('>HBB', header)
        if zero != 0:
            # Typical cause: the file is still gzip-compressed or is not IDX at all.
            raise ValueError('{}: not an IDX file (bad magic number)'.format(filename))
        if data_type not in idx_dtypes:
            raise ValueError('{}: unsupported IDX data type 0x{:02X}'.format(filename, data_type))
        raw_shape = f.read(4 * dims)
        if len(raw_shape) != 4 * dims:
            raise ValueError('{}: truncated IDX dimension table'.format(filename))
        shape = struct.unpack('>' + 'I' * dims, raw_shape)
        # reshape raises ValueError when the payload does not match the declared shape.
        return np.frombuffer(f.read(), dtype=idx_dtypes[data_type]).reshape(shape)

cuda:0


In [3]:
"""
LeNet
"""
class LeNet(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images.

    Two 5x5 stride-1 convolutions, each followed by ReLU and 2x2 max-pooling,
    then two fully connected layers. For a 28x28 input the spatial size goes
    28 -> 24 -> 12 -> 8 -> 4, which is where the ``4*4*50`` input width of
    ``fc1`` comes from.

    Parameters
    ----------
    num_classes : int, optional
        Number of output logits. Defaults to 2.
    """

    def __init__(self, num_classes=2):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)`` for input ``(batch, 1, 28, 28)``."""
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # Flatten per sample and keep the batch dimension fixed. A plain
        # view(-1, 4*4*50) would silently produce a different number of rows
        # for inputs that are not 28x28; this way fc1 raises instead.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def name(self):
        """Return the model's display name."""
        return "LeNet"

def accuracy(predict_y, test_y):
    """Return the fraction of predictions that equal their reference labels.

    Parameters
    ----------
    predict_y : array-like or torch.Tensor
        Predicted labels, one per sample (a trailing singleton axis is allowed).
    test_y : array-like or torch.Tensor
        Reference labels, one per sample.

    Returns
    -------
    float
        Number of matching positions divided by the number of samples.

    Raises
    ------
    ValueError
        If the inputs hold a different number of elements, or are empty.
    """
    def _flat(values):
        # Tensors may live on another device or track gradients; bring them to NumPy first.
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        return np.asarray(values).reshape(-1)

    predicted = _flat(predict_y)
    expected = _flat(test_y)
    if predicted.size != expected.size:
        # A pairwise zip would silently drop the surplus entries and skew the score.
        raise ValueError(
            'predict_y has {} elements but test_y has {}'.format(predicted.size, expected.size)
        )
    if expected.size == 0:
        raise ValueError('cannot compute accuracy of an empty label set')
    return int(np.count_nonzero(predicted == expected)) / expected.size

In [None]:
"""
Imbalanced MNIST DATA
"""
# Load the raw MNIST arrays from the IDX files.
x_train = read_idx('data/train-images.idx3-ubyte')
y_train = read_idx('data/train-labels.idx1-ubyte')
x_test = read_idx('data/t10k-images.idx3-ubyte')
y_test = read_idx('data/t10k-labels.idx1-ubyte')

# Keep only the digits 4 and 9 and relabel them as a binary problem: 4 -> 0, 9 -> 1.
four_index = y_train == 4
nine_index = y_train == 9

y_fours = y_train[four_index]-4
y_nines = y_train[nine_index]-8
x_fours = x_train[four_index]
x_nines = x_train[nine_index]

# Hold out 250 samples of each class as a balanced validation set.
# NOTE(review): no random seed is set, so the split differs from run to run.
valid_four_index = np.random.choice(range(len(y_fours)), size=250, replace=False)
valid_nine_index = np.random.choice(range(len(y_nines)), size=250, replace=False)

x_valid_four =  x_fours[valid_four_index]
x_valid_nine =  x_nines[valid_nine_index]
y_valid_four =  y_fours[valid_four_index]
y_valid_nine =  y_nines[valid_nine_index]

x_valid = np.append(x_valid_four,x_valid_nine,0)
y_valid = np.append(y_valid_four,y_valid_nine)
indices = np.arange(x_valid.shape[0])
np.random.shuffle(indices)
x_valid = x_valid[indices]
y_valid = y_valid[indices]

# Remove the held-out samples from the per-class training pools. Each class
# must be filtered with its OWN validation indices; using the fours' indices
# on the nines would leave the validation nines in the training pool.
x_train_four = np.delete(x_fours, valid_four_index, axis=0)
y_train_four = np.delete(y_fours, valid_four_index)
x_train_nine = np.delete(x_nines, valid_nine_index, axis=0)
y_train_nine = np.delete(y_nines, valid_nine_index)

# Sanity check: each pool lost exactly its own validation samples.
assert len(x_train_four) == len(x_fours) - len(valid_four_index)
assert len(x_train_nine) == len(x_nines) - len(valid_nine_index)

# Build the 4-vs-9 test set the same way (all test samples of both classes, shuffled).
four_index = y_test == 4
nine_index = y_test == 9

y_fours = y_test[four_index]-4
y_nines = y_test[nine_index]-8
x_fours = x_test[four_index]
x_nines = x_test[nine_index]
x_test = np.append(x_fours,x_nines,0)
y_test = np.append(y_fours,y_nines)
indices = np.arange(x_test.shape[0])
np.random.shuffle(indices)
x_test = x_test[indices]
y_test = y_test[indices]

In [None]:
"""
Imbalanced MNIST Study
""" 
num_minority = [500,250,150,50,25]
ratio = 25 # number of minority-class (digit 4) samples; pick one from num_minority

# The training set always holds 5000 samples: `ratio` fours and the rest nines.
print('ratio: 4:9 = {}:{}'.format(ratio, 5000-ratio))
four_part = np.random.choice(range(len(x_train_four)), size=ratio, replace=False)
nine_part = np.random.choice(range(len(x_train_nine)), size=5000-ratio, replace=False)
x_train = np.append(x_train_four[four_part],x_train_nine[nine_part],0)
y_train = np.append(y_train_four[four_part],y_train_nine[nine_part])
indices = np.arange(x_train.shape[0])
np.random.shuffle(indices)
x_train = x_train[indices]
y_train = y_train[indices]


def _to_image_tensor(images):
    """Convert an (N, H, W) image stack into an (N, 1, H, W) tensor via ToTensor.

    ToTensor expects H x W x C and returns C x H x W, so the sample axis is
    moved last with (1, 2, 0). That keeps height and width in place; the
    earlier (2, 1, 0) swapped them and transposed every image.
    """
    return torchvision.transforms.ToTensor()(np.transpose(images, (1, 2, 0))).unsqueeze(1)


# Build the tensors without rebinding x_valid / x_test, so re-running this
# cell does not transpose already-transposed arrays.
x_train_tensor = _to_image_tensor(x_train)
x_valid_tensor = _to_image_tensor(x_valid)
x_test_tensor = _to_image_tensor(x_test)
# np.long was removed from NumPy; np.int64 is the explicit 64-bit integer
# type and maps to torch.int64 (torch.long).
y_train_tensor = torch.from_numpy(y_train.astype(np.int64))
y_valid_tensor = torch.from_numpy(y_valid.astype(np.int64))
y_test_tensor = torch.from_numpy(y_test.astype(np.int64))

In [None]:
"""
Imbalanced MNIST without reweight
"""
# Fresh LeNet for the run without reweighting, moved to the selected device.
lenet = LeNet().to(device)

# Hyper-parameters handed to StandardTrainingNN, gathered in one place.
baseline_options = dict(
    batch_size=100,
    num_iter=80,
    learning_rate=1e-3,
    early_stopping=80,
    device=device,
    iprint=1,
)
stand_trainNN = StandardTrainingNN(lenet, **baseline_options)
stand_trainNN.fit(
    x_train_tensor, y_train_tensor,
    x_valid_tensor, y_valid_tensor,
    x_test_tensor, y_test_tensor,
)

# Score the trained model on the held-out test set.
test_output_y = stand_trainNN.predict(x_test_tensor)
test_labels = y_test_tensor.data.numpy()
test_accuracy = accuracy(test_output_y, test_labels)

print('test accuracy is {}%'.format(100 * test_accuracy))

In [None]:
"""
Imbalanced MNIST with reweight
"""
# Fresh LeNet for the reweighted run, moved to the selected device.
lenet = LeNet().to(device)

# Hyper-parameters handed to TrajectoryReweightNN, gathered in one place.
reweight_options = dict(
    burnin=5,
    num_cluster=5,
    batch_size=100,
    num_iter=75,
    learning_rate=1e-3,
    early_stopping=20,
    device=device,
    iprint=2,
)
tra_weightNN = TrajectoryReweightNN(lenet, **reweight_options)
tra_weightNN.fit(
    x_train_tensor, y_train_tensor,
    x_valid_tensor, y_valid_tensor,
    x_test_tensor, y_test_tensor,
)

# Score the trained model on the held-out test set.
test_output_y = tra_weightNN.predict(x_test_tensor)
test_labels = y_test_tensor.data.numpy()
test_accuracy = accuracy(test_output_y, test_labels)

print('test accuracy is {}%'.format(100 * test_accuracy))