In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
import torch
import torch.nn as nn
import dtnnlib as dtnn

from torchvision import datasets, transforms as T
from torch.utils import data

In [None]:
# Allow PyTorch to use faster reduced-precision kernels (e.g. TF32) for float32
# matrix multiplications where the hardware supports them; per the PyTorch docs,
# "high" trades a little matmul precision for speed.
torch.set_float32_matmul_precision('high')

In [None]:
from tqdm import tqdm
import os, time, sys, random
import json

In [None]:
# ToTensor yields pixels in [0, 1]; normalizing with mean = std = 0.5 maps them to [-1, 1].
# (Despite the name, this transform is applied to FashionMNIST below.)
mnist_transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5], std=[0.5])])

# Both splits are cached under data/ and fetched on first use (download=True).
train_dataset = datasets.FashionMNIST(
    root="data/", train=True, download=True, transform=mnist_transform
)
test_dataset = datasets.FashionMNIST(
    root="data/", train=False, download=True, transform=mnist_transform
)

In [None]:
batch_size = 50  # samples per mini-batch for both splits
# Training batches are reshuffled on every pass; the test order stays fixed.
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Multi-class classification loss on raw logits, averaged over the batch.
criterion = nn.CrossEntropyLoss(reduction="mean")

In [None]:
# Sanity check: take only the first training batch, move it to the device and
# print its shapes.
for xx, yy in train_loader:
    xx = xx.to(device)
    yy = yy.to(device)
    print(xx.shape, yy.shape)
    break

## Any function as a distance metric

In [None]:
class FunctionDT(nn.Module):
    """Distance-transform layer whose "distance" is an arbitrary function.

    For every input row ``x`` and every learnable center ``c_j``, the difference
    ``x - c_j`` is fed to ``func``. The layer outputs
    ``(1 - func(x - c_j)) * exp(inv_temp)`` for each center.

    Args:
        input_dim: size of the last input dimension (and of each center).
        num_centers: number of learnable centers, i.e. the output width.
        func: callable mapping an ``(N, input_dim)`` tensor to ``N`` values
            (e.g. shape ``(N, 1)``), i.e. ``[input_dim -> 1]`` per row.
        inv_temp: initial inverse temperature. It is stored as its logarithm,
            so the applied scale ``exp(self.inv_temp)`` stays positive.
    """

    def __init__(self, input_dim, num_centers, func, inv_temp=1.):
        super().__init__()
        self.input_dim = input_dim
        self.num_centers = num_centers
        self.func = func

        # Log-space parameter; forward() exponentiates it.
        self.inv_temp = nn.Parameter(torch.ones(1) * np.log(inv_temp))

        # Centers start as Gaussian noise scaled by 1/3.
        self.centers = nn.Parameter(torch.randn(num_centers, input_dim) / 3.)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_dim) to scores of shape (batch, num_centers)."""
        # (batch, 1, dim) - (1, centers, dim) -> (batch, centers, dim)
        diffs = x.unsqueeze(1) - self.centers.unsqueeze(0)
        batch, n_centers, dim = diffs.shape
        # Evaluate func on all (sample, center) differences as one flat batch.
        scores = self.func(diffs.view(-1, dim)).view(batch, n_centers)
        return (1 - scores) * torch.exp(self.inv_temp)

In [None]:
# Fetch the Input-Invex-Neural-Network repository; the next cell adds it to
# sys.path so its `classes` and `nflib` modules can be imported.
# On re-runs git refuses to clone into the existing directory and only reports that.
# NOTE(review): the clone is unpinned (no commit/tag), so results may drift if upstream changes.
!git clone https://github.com/tsumansapkota/Input-Invex-Neural-Network.git

In [None]:
# Make the cloned repository importable. The guard keeps re-runs of this cell
# from appending the same entry to sys.path again and again.
if "Input-Invex-Neural-Network" not in sys.path:
    sys.path.append("Input-Invex-Neural-Network")

In [None]:
from classes import DistanceRegressor, ConvexNN
from nflib.flows import SequentialFlow, ActNorm
import nflib.res_flow as irf

## Merge all models into a single benchmark

In [None]:
models_keys = ["l_0.5", "l_1", "l_2", "l_20", "stereo", "linear"]
def get_models(h=5, key='linear', input_dim=784, num_classes=10):
    """Build a baseline classifier whose first layer is selected by `key`.

    Args:
        h: width of the first layer.
        key: "l_0.5", "l_1", "l_20" (dtnn.DistanceTransform_Simple with
            p=0.5 / 1 / 20), "l_2" (DistanceTransform_Simple with its default p),
            "stereo" (dtnn.iStereographicLinearTransform) or "linear" (nn.Linear).
            All first layers are built with bias=False.
        input_dim: flattened input size (784 = 28*28 for FashionMNIST).
        num_classes: number of output logits.

    Returns:
        nn.Sequential: first layer -> BatchNorm1d(affine=False) -> LayerNorm -> ELU -> Linear.

    Raises:
        KeyError: if `key` is not one of the keys listed above.
    """
    # Builders are deferred (lambdas) so only the selected layer is constructed.
    first_layer_builders = {
        "l_0.5": lambda: dtnn.DistanceTransform_Simple(input_dim, h, p=0.5, bias=False),
        "l_1": lambda: dtnn.DistanceTransform_Simple(input_dim, h, p=1, bias=False),
        # NOTE(review): no p is passed, so this relies on DistanceTransform_Simple's
        # default p (presumably 2, matching the key) — confirm in dtnnlib.
        "l_2": lambda: dtnn.DistanceTransform_Simple(input_dim, h, bias=False),
        "l_20": lambda: dtnn.DistanceTransform_Simple(input_dim, h, p=20, bias=False),
        "stereo": lambda: dtnn.iStereographicLinearTransform(input_dim, h, bias=False),
        "linear": lambda: nn.Linear(input_dim, h, bias=False),
    }
    if key not in first_layer_builders:
        raise KeyError(
            f"unknown model key {key!r}; expected one of {list(first_layer_builders)}"
        )
    layer1 = first_layer_builders[key]()

    return nn.Sequential(
        layer1,
        nn.BatchNorm1d(h, affine=False),
        nn.LayerNorm(h),
        nn.ELU(),
        nn.Linear(h, num_classes),
    )

In [None]:
# Override of the previous cell. Clearing the baseline key list means only the
# function-metric models (models_func_keys) are benchmarked below, and the
# baseline factory is replaced with a stub.
# NOTE(review): this shadows the real get_models defined above. If models_keys is
# ever repopulated while this stub is active, the nn.Identity "model" would emit
# the raw 784-dim inputs instead of 10 logits. Restore the real factory first.
models_keys = list()


def get_models(h=5, key='linear'):
    """Stub baseline factory: ignores `h` and `key` and returns ``nn.Identity()``."""
    return nn.Identity()

In [None]:
def get_children(module):
    """Return the leaf sub-modules of `module` in depth-first order.

    A module without children is its own leaf, so ``get_children(leaf) == [leaf]``.
    """
    direct_children = list(module.children())
    if not direct_children:
        return [module]
    return [leaf for child in direct_children for leaf in get_children(child)]


def remove_spectral_norm(model, verbose=True):
    """Best-effort removal of spectral normalization from every leaf of `model`.

    Each leaf that has a ``weight`` attribute is passed in turn to
    ``irf.remove_spectral_norm_conv``, ``irf.remove_spectral_norm`` and
    ``torch.nn.utils.remove_spectral_norm``. Every attempt is independent: a
    failing attempt (e.g. that flavour of spectral norm was never applied) is
    reported and otherwise ignored, so removal attempts never make this raise.

    Args:
        model: module whose leaves are handed to the removal routines
            (torch's remove_spectral_norm modifies the module in place).
        verbose: if True (default), print one line per weighted leaf and one per
            attempt, exactly as before. Pass False to silence the diagnostics.

    Returns:
        None
    """
    # Lambdas defer the attribute lookups, so a missing `irf` helper counts as an
    # ordinary failed attempt instead of aborting the whole loop.
    removers = (
        ("irf conv", lambda m: irf.remove_spectral_norm_conv(m)),
        ("irf lin", lambda m: irf.remove_spectral_norm(m)),
        ("nn", lambda m: nn.utils.remove_spectral_norm(m)),
    )
    for child in get_children(model):
        if not hasattr(child, 'weight'):
            continue
        if verbose:
            print("Yes", child)
        for label, remover in removers:
            try:
                remover(child)
            except Exception:  # deliberate best-effort: this flavour may not apply
                if verbose:
                    print(f"Failed : {label}")
            else:
                if verbose:
                    print(f"Success : {label}")

In [None]:
models_func_keys = ["convex", "invex", "ordinary"]

def get_models_func(h=500, func_h=500, key='ordinary', input_dim=784, num_classes=10):
    """Build a classifier whose first layer is a FunctionDT with a learned metric.

    Args:
        h: number of FunctionDT centers (width of the first layer).
        func_h: hidden width of the metric network.
        key: "convex" (metric = ConvexNN([input_dim, func_h+2, func_h+1, 1])),
            "invex" (ActNorm -> irf.ResidualFlow -> ActNorm -> DistanceRegressor)
            or "ordinary" (the same network as "invex", passed through
            remove_spectral_norm; presumably this lifts the flow's Lipschitz /
            invertibility constraint — TODO confirm).
        input_dim: flattened input size (784 = 28*28 for FashionMNIST).
        num_classes: number of output logits.

    Returns:
        nn.Sequential: FunctionDT -> BatchNorm1d(affine=False) -> LayerNorm -> ELU -> Linear.

    Raises:
        KeyError: if `key` is not "convex", "invex" or "ordinary".
    """
    if key not in ("convex", "invex", "ordinary"):
        raise KeyError(
            f"unknown function-model key {key!r}; expected one of ['convex', 'invex', 'ordinary']"
        )

    if key == "convex":
        layer1 = ConvexNN([input_dim, func_h + 2, func_h + 1, 1])
    else:
        # "invex" and "ordinary" share one architecture; they differ only in
        # whether spectral normalization is stripped afterwards.
        layer1 = nn.Sequential(
            ActNorm(input_dim),
            irf.ResidualFlow(input_dim, [func_h, func_h], activation=irf.LeakyReLU),
            ActNorm(input_dim),
            DistanceRegressor(input_dim),
        )
        if key == "ordinary":
            remove_spectral_norm(layer1)

    return nn.Sequential(
        FunctionDT(input_dim, h, layer1),
        nn.BatchNorm1d(h, affine=False),
        nn.LayerNorm(h),
        nn.ELU(),
        nn.Linear(h, num_classes),
    )

In [None]:
## The following train/test loops are adapted from
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch, model, optimizer):
    """Run one training epoch of `model` over the global `train_loader`.

    Reads the notebook globals ``train_loader``, ``device`` and ``criterion``.
    Images are flattened to 28*28 = 784-dim vectors before the forward pass.

    Args:
        epoch: epoch index (not used inside; kept for interface compatibility).
        model: network mapping (batch, 784) inputs to class logits.
        optimizer: optimizer that updates `model`'s parameters.

    Returns:
        tuple[float, float]: mean per-batch training loss and training accuracy
        (percent) for this epoch, or (0.0, 0.0) if the loader yields no batches.
    """
    model.train()
    train_loss = 0.0
    n_batches = 0
    correct = 0
    total = 0
    for inputs, targets in train_loader:
        inputs = inputs.to(device).view(-1, 28 * 28)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    if total == 0:
        # Empty loader: nothing was seen, so avoid dividing by zero.
        return 0.0, 0.0
    return train_loss / n_batches, 100. * correct / total

In [None]:
best_acc = -1
def test(epoch, model, model_name):
    """Evaluate `model` on the global `test_loader` and track the best accuracy.

    Reads the notebook globals ``test_loader`` and ``device``. Updates the global
    ``best_acc`` (percent) whenever this evaluation beats it.

    Args:
        epoch: epoch index (not used inside; kept for interface compatibility).
        model: network mapping (batch, 784) inputs to class logits.
        model_name: run identifier (currently unused; no checkpoint is written).

    Returns:
        float: test accuracy in percent for this evaluation.
    """
    global best_acc
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device).view(-1, 28 * 28)
            targets = targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total
    if acc > best_acc:
        # The previous version built a checkpoint dict here but never saved it.
        # Only the accuracy is tracked; add torch.save(...) if checkpoints are needed.
        best_acc = acc
    return acc

In [None]:
# Optimization settings shared by every run.
learning_rate = 1e-4  # Adam step size (cosine-annealed over EPOCHS)
EPOCHS = 50           # training epochs per (model, width, seed) run

In [None]:
# First-layer widths (h) swept by the benchmark.
H = [5, 10, 20, 100, 500]

# Display which model families will be run.
(models_keys, models_func_keys)

In [None]:
exp_acc_vals = dict()  # {str(h): {model_key: [best test accuracy per seed]}}

In [None]:
# Resume from previously saved results if the JSON file exists and parses;
# otherwise keep the current `exp_acc_vals` unchanged.
try:
    with open("./outputs/03_exp_acc_data_LN_BN.json", 'r') as f:
        exp_acc_vals = json.load(f)
except (OSError, ValueError) as err:
    # A missing or unreadable file (OSError) or corrupt JSON (json.JSONDecodeError
    # is a ValueError) means a fresh start. Say why instead of failing silently,
    # and let unrelated errors such as KeyboardInterrupt propagate.
    print(f"No saved results loaded ({type(err).__name__}: {err})")

In [None]:
# Show the results loaded so far (an empty dict on a fresh run).
exp_acc_vals

In [None]:
# Fixed seeds: every (model, width) configuration is trained once per seed.
SEEDS = [
    147, 258, 369, 741,
    852, 963, 159, 357,
]

In [None]:
### Initialization
# Rebuild exp_acc_vals[str(h)] for every width in H. A model's saved results are
# kept only if all len(SEEDS) runs are present; otherwise its entry is reset to []
# so the experiment loop re-runs it. Widths not in H are left untouched.
for h in H:
    saved_for_h = exp_acc_vals.get(str(h), {})
    acc_dict = {}
    for key in models_keys + models_func_keys:
        print(f"Checking for {key} ; h:{h}")
        results = saved_for_h.get(str(key))
        if results is not None:
            print(results)
            if len(results) == len(SEEDS):
                print("Results found complete")
                acc_dict[str(key)] = results
                continue
        acc_dict[str(key)] = []
    exp_acc_vals[str(h)] = acc_dict
exp_acc_vals

In [None]:
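# Placeholder cell (previously the junk comment "asdsad"); presumably a bare name was put here to halt Run All before the long training loop below — TODO confirm or delete.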

In [None]:
def _dump_results_atomically(obj, path):
    """Write `obj` to `path` as indented JSON without leaving a half-written file.

    The JSON goes to ``path + ".tmp"`` first and is then moved over `path` with
    os.replace. An interrupted write therefore keeps the previous results file
    intact instead of leaving a truncated one that would fail to load on the
    next run. The parent directory is created if it does not exist.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(obj, f, indent=3)
    os.replace(tmp_path, path)


# Main experiment: for every width h and model key, train one model per seed and
# record its best test accuracy. Keys whose results are already complete (all
# len(SEEDS) runs) are skipped, so the cell can be resumed after an interruption.
# NOTE: this deliberately runs at module level. Rebinding train_loader,
# test_loader and best_acc here updates the notebook globals that train()/test() read.
for h in H:
    acc_dict = exp_acc_vals[str(h)]

    for key, func_idx in zip(models_keys + models_func_keys,
                             [0] * len(models_keys) + [1] * len(models_func_keys)):
        print("_________________________")
        print(f"Experimenting for {key} ; h:{h}")

        results = acc_dict.get(str(key))
        if results is not None:
            print(results)
            if len(results) == len(SEEDS):
                print("Results found complete")
                continue

        acc_dict[str(key)] = []
        for seed in tqdm(SEEDS):
            model_name = f"03_fmnist_{key}_h{h}_s{seed}"

            # Seed torch, numpy and random before building the loaders and model,
            # so initialization and data order follow from the seed.
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
            test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

            if func_idx == 0:
                net = get_models(h, key=key).to(device)
            else:
                net = get_models_func(h, key=key).to(device)

            optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
            best_acc = -1  # reset the global that test() raises to this run's best accuracy
            for epoch in range(EPOCHS):
                train(epoch, net, optimizer)
                test(epoch, net, model_name)
                scheduler.step()
            acc_dict[str(key)] += [float(best_acc)]  # add to the list

        exp_acc_vals[str(h)] = acc_dict

        # Checkpoint after every completed key (atomic write; creates ./outputs/ if needed).
        _dump_results_atomically(exp_acc_vals, "./outputs/03_exp_acc_data_LN_BN.json")

In [None]:
# import json
# with open("./outputs/03_exp_acc_data_LN_BN.json", 'r') as f:
#     exp_acc_vals = json.load(f)

In [None]:
# Display the accumulated results: best test accuracy per seed, keyed by width, then by model.
exp_acc_vals

In [None]:
# Summarize each (width, model) entry as [mean, std, max] of its per-seed best
# accuracies, rounded to 2 decimals. Entries with no runs yet get NaNs instead of
# crashing np.max on an empty list.
# The loop variable is `accs`, not `data`, so the torch.utils.data alias used by
# the experiment cell is not clobbered.
final_stats = {}
for h, results_by_key in exp_acc_vals.items():
    final_stats[h] = {}
    for key, accs in results_by_key.items():
        if len(accs) == 0:
            final_stats[h][key] = [float("nan")] * 3
            continue
        final_stats[h][key] = [
            np.round(np.mean(accs), 2),
            np.round(np.std(accs), 2),
            np.round(np.max(accs), 2),
        ]
final_stats

In [None]:
# Report the parameter count of every (width, model) configuration. The models
# are built on the CPU only: counting parameters does not need the GPU.
for h in H:
    for key, func_idx in zip(models_keys + models_func_keys,
                             [0] * len(models_keys) + [1] * len(models_func_keys)):
        print("_________________________")
        print(f"Testing for {key} ; h:{h}")
        if func_idx == 0:
            net = get_models(h, key=key)
        else:
            net = get_models_func(h, key=key)
        print("Params:", sum(p.numel() for p in net.parameters()))