In [1]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [2]:
import torch
import torch.nn as nn
import dtnnlib as dtnn
# import resnet_cifar

from torchvision import datasets, transforms as T
from torch.utils import data

In [3]:
from tqdm import tqdm
import os, time, sys, random
import json

In [4]:
# Convert each image to a tensor, then normalize its single channel with
# mean 0.5 and std 0.5.
mnist_transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5], std=[0.5])])

# The train and test splits differ only in the `train` flag; both use the same
# root directory and transform, and request a download (download=True).
train_dataset, test_dataset = (
    datasets.FashionMNIST(root="../../../../_Datasets/", train=is_train, download=True, transform=mnist_transform)
    for is_train in (True, False)
)

In [5]:
# train_dataset.data = train_dataset.data.view(-1, 28*28)
# test_dataset.data = test_dataset.data.view(-1, 28*28)

In [6]:
batch_size = 50

# Both loaders share the batch size and worker count; only the training
# split is shuffled.
train_loader, test_loader = (
    data.DataLoader(dataset=split, num_workers=4, batch_size=batch_size, shuffle=do_shuffle)
    for split, do_shuffle in ((train_dataset, True), (test_dataset, False))
)

In [7]:
# Prefer the first CUDA GPU, but fall back to the CPU so that every later
# `.to(device)` call still works on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [8]:
# Classification loss shared by the train() and test() loops below.
criterion = nn.CrossEntropyLoss()

In [9]:
# Sanity check: pull a single batch, move it to the target device and show
# the tensor shapes. Only the first batch is needed, hence the `break`.
for xx, yy in train_loader:
    xx = xx.to(device)
    yy = yy.to(device)
    print(xx.shape, yy.shape)
    break



torch.Size([50, 1, 28, 28]) torch.Size([50])


## Any function as metric

In [10]:
class FunctionDT(nn.Module):
    """Distance-transform style layer whose "metric" is an arbitrary callable.

    Each input is compared with ``num_centers`` learnable centers: the
    difference ``x - center`` is fed through ``func`` and the negated result,
    scaled by ``exp(inv_temp)``, is returned as the activation of that center.
    """

    def __init__(self, input_dim, num_centers, func, inv_temp=0.):
        '''
        input_dim   : size of the last dimension of the input / of each center
        num_centers : number of learnable centers (one output per center)
        func        : callable [input_dim -> 1], applied to the differences
        inv_temp    : initial value of the learnable log-scale; the output is
                      multiplied by exp(inv_temp)
        '''
        super().__init__()
        self.input_dim = input_dim
        self.num_centers = num_centers
        self.func = func

        # Registered before `centers` so the parameter order stays inv_temp, centers.
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)

        # Centers start as small Gaussian noise (std 1/3).
        self.centers = nn.Parameter(torch.randn(num_centers, input_dim)/3.)

    def forward(self, x):
        # Broadcast x[:, None] against centers[None] to obtain every
        # (sample, center) difference.
        # assumes x is (batch, input_dim) -- TODO confirm against callers
        diffs = x[:, None] - self.centers[None]
        scale = torch.exp(self.inv_temp)

        ## Output is left un-normalized; the per-sample standardization below
        ## is intentionally commented out.
#         dists = dists-dists.mean(dim=1, keepdim=True)
#         dists = dists/dists.std(dim=1, keepdim=True)

        return -self.func(diffs).squeeze(-1)*scale

In [11]:
from classes import DistanceRegressor, ConvexNN
from nflib.flows import SequentialFlow, ActNorm
import nflib.res_flow as irf

## Merge all models into a single setup and benchmark them

In [12]:
models_keys = ["l_0.5", "l_1", "l_2", "l_20", "stereo", "linear",]
def get_models(h = 5, key='linear', input_dim=784, num_classes=10):
    """Build a two-layer classifier whose first layer is selected by ``key``.

    Args:
        h: width of the hidden representation (output size of the first layer).
        key: one of ``models_keys``:
            "l_0.5", "l_1", "l_20" -> dtnn.DistanceTransform with p=0.5 / 1 / 20
            "l_2"    -> dtnn.DistanceTransform with its default p
            "stereo" -> dtnn.StereographicTransform
            "linear" -> nn.Linear
            All first layers are created with bias=False.
        input_dim: flattened input size (default 784).
        num_classes: number of output logits (default 10).

    Returns:
        nn.Sequential: first layer -> LayerNorm(h) -> ELU -> Linear(h, num_classes).

    Raises:
        KeyError: if ``key`` is not a known model key; the message names the
            offending key and the accepted ones.
    """
    layer1 = None
    if key == "l_0.5":
        layer1 = dtnn.DistanceTransform(input_dim, h, p=0.5, bias=False)
    elif key == "l_1":
        layer1 = dtnn.DistanceTransform(input_dim, h, p=1, bias=False)
    elif key == "l_2":
        layer1 = dtnn.DistanceTransform(input_dim, h, bias=False)
    elif key == "l_20":
        layer1 = dtnn.DistanceTransform(input_dim, h, p=20, bias=False)
    elif key == "stereo":
        layer1 = dtnn.StereographicTransform(input_dim, h, bias=False)
    elif key == "linear":
        layer1 = nn.Linear(input_dim, h, bias=False)
    else:
        raise KeyError(f"Unknown model key {key!r}; expected one of {models_keys}")

    net = nn.Sequential(
        layer1,
#         nn.BatchNorm1d(h),
        nn.LayerNorm(h),
        nn.ELU(),
        nn.Linear(h, num_classes),
        )
    return net

In [13]:
models_func_keys = ["convex", "invex", "ordinary"]

def get_models_func(h = 500, func_h=500, key='ordinary', input_dim=784):
    """Build a classifier whose first layer is a ``FunctionDT`` over a learned function.

    Args:
        h: number of FunctionDT centers (width of the hidden representation).
        func_h: hidden width of the network used as the function.
        key: one of ``models_func_keys``:
            "convex"   -> ConvexNN([input_dim, func_h, 1])
            "invex"    -> ActNorm -> ResidualFlow -> ActNorm -> DistanceRegressor
            "ordinary" -> ActNorm -> ResidualFlow -> ActNorm -> Linear(input_dim, 1),
                          then passed to irf.remove_spectral_norm_model
        input_dim: flattened input size (default 784, i.e. 28*28).

    Returns:
        nn.Sequential: FunctionDT -> BatchNorm1d(h) -> ELU -> Linear(h, 10).

    Raises:
        KeyError: if ``key`` is not a known function key; the message names the
            offending key and the accepted ones.
    """
    layer1 = None
    if key == "convex":
        layer1 = ConvexNN([input_dim, func_h, 1])
    elif key == "invex":
        layer1 = nn.Sequential(
                    ActNorm(input_dim),
                    irf.ResidualFlow(input_dim, [func_h], activation=irf.LeakyReLU),
                    ActNorm(input_dim),
                    DistanceRegressor(input_dim),
                    )
    elif key == "ordinary":
        layer1 = nn.Sequential(
                    ActNorm(input_dim),
                    irf.ResidualFlow(input_dim, [func_h], activation=irf.LeakyReLU),
                    ActNorm(input_dim),
#                     DistanceRegressor(784),
                    nn.Linear(input_dim, 1),
                    )
        # Only the "ordinary" variant has spectral normalization removed.
        irf.remove_spectral_norm_model(layer1)
    else:
        raise KeyError(f"Unknown function key {key!r}; expected one of {models_func_keys}")

    net = nn.Sequential(
        FunctionDT(input_dim, h, layer1),
        nn.BatchNorm1d(h),
#         nn.LayerNorm(h),
        nn.ELU(),
        nn.Linear(h, 10),
        )
    return net

In [14]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch, model, optimizer):
    """Run one optimization pass over the global ``train_loader``.

    Relies on the notebook globals ``train_loader``, ``device`` and
    ``criterion``. Every batch is flattened to (-1, 28*28) before the forward
    pass.

    Args:
        epoch: index of the current epoch (kept for logging; not used in the
            computation).
        model: module to optimize; switched to training mode.
        optimizer: optimizer stepped once per batch.

    Returns:
        tuple (mean_loss, accuracy): mean per-batch loss and top-1 accuracy in
        percent over the pass, or (0.0, 0.0) if the loader yields no batches.
        Earlier versions returned None; callers that ignore the result are
        unaffected.
    """
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device).view(-1, 28*28), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    # Guard against an empty loader instead of dividing by zero.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0
    return train_loss/num_batches, 100.*correct/total

In [15]:
best_acc = -1
def test(epoch, model, model_name):
    """Evaluate ``model`` on the global ``test_loader`` and track the best accuracy.

    Relies on the notebook globals ``test_loader``, ``device`` and
    ``criterion``, and updates the global ``best_acc`` whenever the measured
    accuracy exceeds it.

    Args:
        epoch: index of the current epoch (only relevant for checkpoint
            metadata, which is currently disabled).
        model: module to evaluate; switched to eval mode.
        model_name: checkpoint file stem; unused while checkpointing is disabled.

    Returns:
        tuple (mean_loss, accuracy): mean per-batch loss and top-1 accuracy in
        percent, or (0.0, 0.0) if the loader yields no samples (``best_acc``
        is left untouched in that case). Earlier versions returned None;
        callers that ignore the result are unaffected.
    """
    global best_acc
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device).view(-1, 28*28), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    # An empty loader previously raised ZeroDivisionError here.
    if total == 0:
        return 0.0, 0.0

    acc = 100.*correct/total
    if acc > best_acc:
        # Checkpoint saving is disabled, so no state dict is materialized.
        # To re-enable it, save
        #   {'model': model.state_dict(), 'acc': acc, 'epoch': epoch}
        # to f'./models/{model_name}.pth' (creating ./models first).
        best_acc = acc
    return test_loss/num_batches, acc

In [16]:
# Learning rate handed to the Adam optimizer for every model in the sweep.
learning_rate = 0.0001
# Epochs per (model, seed) run; also used as T_max of the cosine LR schedule.
EPOCHS = 50

In [17]:
# H = [5, 10, 20, 100, 500]
# Hidden widths swept by the experiment loop (5 is excluded from this run).
H = [10, 20, 100, 500]

# Display the model keys that the experiment loop iterates over.
models_keys, models_func_keys

(['l_0.5', 'l_1', 'l_2', 'l_20', 'stereo', 'linear'],
 ['convex', 'invex', 'ordinary'])

In [18]:
# ! mkdir outputs/00.2_exp_acc/

In [19]:
# Resume from previously saved results when the file exists; on a fresh run
# (no file yet) start with an empty dict instead of failing with
# FileNotFoundError.
exp_acc_vals = {}
if os.path.isfile("./outputs/00.2_exp_acc_data_v2.json"):
    with open("./outputs/00.2_exp_acc_data_v2.json", 'r') as f:
        exp_acc_vals = json.load(f)

In [22]:
# Inspect the results loaded from disk (keyed by hidden width, then model key).
exp_acc_vals

{'5': {'l_0.5': [71.49, 73.15, 73.99, 73.01, 72.52, 72.96, 74.34, 73.66],
  'l_1': [75.74, 75.33, 77.31, 76.17, 75.78, 77.11, 77.49, 76.82],
  'l_2': [76.58, 77.89, 78.05, 77.67, 76.98, 78.75, 78.94, 78.06],
  'l_20': [79.2, 79.7, 79.59, 78.49, 78.59, 78.85, 78.64, 75.73],
  'stereo': [81.37, 80.73, 80.81, 81.9, 81.09, 80.2, 81.34, 81.82],
  'linear': [81.22, 81.54, 81.04, 81.06, 81.5, 81.14, 80.99, 81.5],
  'convex': [77.49, 68.88, 78.08, 80.34, 78.26, 80.63, 81.74, 66.97],
  'invex': [88.48, 87.3, 86.87, 87.48, 88.06, 85.9, 87.63, 87.67],
  'ordinary': [82.05, 85.56, 85.61, 83.84, 83.62, 82.49, 82.51, 83.6]},
 '10': {'l_0.5': [77.87, 77.97, 78.2, 78.28, 77.56, 78.15, 78.32, 77.96],
  'l_1': [81.37, 81.15, 81.51, 81.42, 81.05, 81.0, 81.71, 81.46],
  'l_2': [82.3, 82.67, 82.32, 82.42, 81.88, 82.45, 82.64, 82.5],
  'l_20': [82.7, 82.4, 82.17, 82.97, 82.9, 82.09, 82.59, 81.31],
  'stereo': [84.94, 85.31, 85.15, 85.31, 84.9, 84.75, 84.63, 85.19],
  'linear': [84.43, 84.91, 84.94, 85.11, 8

In [None]:
# data_file = "./outputs/00.2_exp_acc_dict_v2.json"
SEEDS = [147, 258, 369, 741, 852, 963, 159, 357]

for h in H:
    # Extend the results already stored for this width instead of starting
    # from an empty dict: previously, resuming replaced the loaded results of
    # every skipped key with an empty list and wrote that back to disk.
    acc_dict = exp_acc_vals.setdefault(str(h), {})

    for key in models_keys + models_func_keys:
        print("_________________________")
        print(f"Experimenting for {key} ; h:{h}")

        # Resume filter: for h == 10 only "invex" is (re)run; the stored
        # results of all other keys are kept as they are.
        if h == 10 and key != "invex": continue

        # Reset only the entry that is about to be recomputed.
        acc_dict[key] = []

        for seed in tqdm(SEEDS):
            model_name = f"00.2_fmnist_{key}_h{h}_s{seed}"

            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            # Rebuilt after seeding; train() and test() read these loaders
            # from the notebook globals.
            train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
            test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

            if key in models_func_keys:
                net = get_models_func(h, key=key).to(device)
            else:
                net = get_models(h, key=key).to(device)

            optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
            # test() raises this global whenever an epoch beats it.
            best_acc = -1
            for epoch in range(EPOCHS):
                train(epoch, net, optimizer)
                test(epoch, net, model_name)
                scheduler.step()
            acc_dict[key].append(float(best_acc)) ## best test accuracy for this seed

        ## Save after every finished model so an interrupted sweep loses little.
        os.makedirs("./outputs", exist_ok=True)
        with open("./outputs/00.2_exp_acc_data_v2.json", "w") as f:
            json.dump(exp_acc_vals, f, indent=3)

_________________________
Experimenting for l_0.5 ; h:10
_________________________
Experimenting for l_1 ; h:10
_________________________
Experimenting for l_2 ; h:10
_________________________
Experimenting for l_20 ; h:10
_________________________
Experimenting for stereo ; h:10
_________________________
Experimenting for linear ; h:10
_________________________
Experimenting for convex ; h:10
_________________________
Experimenting for invex ; h:10


100%|██████████████████████████████████████████████████████| 8/8 [47:23<00:00, 355.49s/it]


_________________________
Experimenting for ordinary ; h:10
_________________________
Experimenting for l_0.5 ; h:20


100%|██████████████████████████████████████████████████████| 8/8 [32:36<00:00, 244.53s/it]


_________________________
Experimenting for l_1 ; h:20


100%|██████████████████████████████████████████████████████| 8/8 [33:26<00:00, 250.85s/it]


_________________________
Experimenting for l_2 ; h:20


100%|██████████████████████████████████████████████████████| 8/8 [34:39<00:00, 259.88s/it]


_________________________
Experimenting for l_20 ; h:20


100%|██████████████████████████████████████████████████████| 8/8 [33:29<00:00, 251.22s/it]


_________________________
Experimenting for stereo ; h:20


100%|██████████████████████████████████████████████████████| 8/8 [34:16<00:00, 257.07s/it]


_________________________
Experimenting for linear ; h:20


100%|██████████████████████████████████████████████████████| 8/8 [31:08<00:00, 233.57s/it]


_________________________
Experimenting for convex ; h:20


100%|██████████████████████████████████████████████████████| 8/8 [38:08<00:00, 286.05s/it]


_________________________
Experimenting for invex ; h:20


100%|██████████████████████████████████████████████████████| 8/8 [54:24<00:00, 408.08s/it]


_________________________
Experimenting for ordinary ; h:20


  0%|                                                               | 0/8 [00:00<?, ?it/s]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 12%|██████▊                                               | 1/8 [05:27<38:10, 327.21s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 25%|█████████████▌                                        | 2/8 [10:54<32:44, 327.36s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 38%|████████████████████▎                                 | 3/8 [16:20<27:14, 326.87s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 50%|███████████████████████████                           | 4/8 [21:47<21:47, 326.77s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 62%|█████████████████████████████████▊                    | 5/8 [27:15<16:21, 327.29s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 75%|████████████████████████████████████████▌             | 6/8 [32:43<10:54, 327.38s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 88%|███████████████████████████████████████████████▎      | 7/8 [38:09<05:26, 326.86s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


100%|██████████████████████████████████████████████████████| 8/8 [43:36<00:00, 327.09s/it]


_________________________
Experimenting for l_0.5 ; h:100


100%|██████████████████████████████████████████████████████| 8/8 [33:18<00:00, 249.76s/it]


_________________________
Experimenting for l_1 ; h:100


100%|██████████████████████████████████████████████████████| 8/8 [33:45<00:00, 253.24s/it]


_________________________
Experimenting for l_2 ; h:100


100%|██████████████████████████████████████████████████████| 8/8 [34:38<00:00, 259.85s/it]


_________________________
Experimenting for l_20 ; h:100


100%|██████████████████████████████████████████████████████| 8/8 [33:49<00:00, 253.69s/it]


_________________________
Experimenting for stereo ; h:100


100%|██████████████████████████████████████████████████████| 8/8 [34:17<00:00, 257.20s/it]


_________________________
Experimenting for linear ; h:100


100%|██████████████████████████████████████████████████████| 8/8 [31:01<00:00, 232.67s/it]


_________________________
Experimenting for convex ; h:100


100%|██████████████████████████████████████████████████████| 8/8 [50:45<00:00, 380.75s/it]


_________________________
Experimenting for invex ; h:100


100%|████████████████████████████████████████████████████| 8/8 [1:49:39<00:00, 822.40s/it]


_________________________
Experimenting for ordinary ; h:100


  0%|                                                               | 0/8 [00:00<?, ?it/s]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 12%|██████▌                                             | 1/8 [12:48<1:29:40, 768.70s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 25%|█████████████                                       | 2/8 [25:37<1:16:53, 768.87s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 38%|███████████████████▌                                | 3/8 [38:26<1:04:03, 768.64s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 50%|███████████████████████████                           | 4/8 [51:14<51:13, 768.46s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 62%|████████████████████████████████▌                   | 5/8 [1:04:00<38:22, 767.64s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 75%|███████████████████████████████████████             | 6/8 [1:16:46<25:34, 767.19s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


 88%|█████████████████████████████████████████████▌      | 7/8 [1:29:31<12:46, 766.55s/it]

Yes Linear(in_features=784, out_features=500, bias=True)
Success
Yes Linear(in_features=500, out_features=784, bias=True)
Success
Yes Linear(in_features=784, out_features=1, bias=True)
Failed


100%|████████████████████████████████████████████████████| 8/8 [1:42:18<00:00, 767.29s/it]


_________________________
Experimenting for l_0.5 ; h:500


100%|██████████████████████████████████████████████████████| 8/8 [44:02<00:00, 330.33s/it]


_________________________
Experimenting for l_1 ; h:500


 12%|██████▊                                               | 1/8 [04:38<32:26, 278.04s/it]

In [None]:
# Inspect the accumulated results after the sweep.
exp_acc_vals

In [None]:
# exp_acc_vals = \
# {'5': {'l_0.5': 67.05,
#   'l_1': 70.7,
#   'l_2': 78.15,
#   'l_inf': 79.52,
#   'stereo': 82.19,
#   'linear': 82.74,
#   'convex': 79.49,
#   'invex': 88.26,
#   'ordinary': 83.55},
#  '10': {'l_0.5': 72.08,
#   'l_1': 77.91,
#   'l_2': 82.35,
#   'l_inf': 83.98,
#   'stereo': 84.73,
#   'linear': 84.89,
#   'convex': 78.99,
#   'invex': 88.41,
#   'ordinary': 81.69}}

In [None]:
# data_file = "./outputs/00.2_exp_acc_dict.json"
# with open(data_file, "w") as f:
#     json.dump(exp_acc_vals, f, indent=3)

In [None]:
# # Opening JSON file
# with open(data_file, 'r') as f:
#     exp_acc_vals = json.load(f)

In [None]:
# Display the results dict currently held in memory.
exp_acc_vals