In [1]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [2]:
import torch
import torch.nn as nn
import dtnnlib as dtnn
# import resnet_cifar

from torchvision import datasets, transforms as T
from torch.utils import data

In [3]:
from tqdm import tqdm
import os, time, sys, random
import json

In [4]:
# ToTensor yields [0, 1] floats for uint8 images; Normalize(0.5, 0.5) then maps them to [-1, 1].
mnist_transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5], std=[0.5])])

# Both splits share one dataset directory; download=True fetches them if absent.
DATASET_ROOT = "../../../../_Datasets/"
train_dataset, test_dataset = (
    datasets.FashionMNIST(root=DATASET_ROOT, train=is_train, download=True, transform=mnist_transform)
    for is_train in (True, False)
)

In [5]:
# train_dataset.data = train_dataset.data.view(-1, 28*28)
# test_dataset.data = test_dataset.data.view(-1, 28*28)

In [6]:
batch_size = 50
# Only the training split is shuffled; the test split keeps a fixed order.
train_loader, test_loader = (
    data.DataLoader(dataset=split, num_workers=4, batch_size=batch_size, shuffle=do_shuffle)
    for split, do_shuffle in ((train_dataset, True), (test_dataset, False))
)

In [7]:
# Prefer cuda:1 when a second GPU exists, but fall back to cuda:0 or the CPU
# so the notebook still runs on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [8]:
# Softmax cross-entropy averaged over the batch (classifier heads output 10 logits).
criterion = nn.CrossEntropyLoss(reduction="mean")

In [9]:
# Smoke test: fetch a single batch and show tensor shapes before training.
xx, yy = next(iter(train_loader))
xx, yy = xx.to(device), yy.to(device)
print(xx.shape, yy.shape)

torch.Size([50, 1, 28, 28]) torch.Size([50])


## Using an arbitrary learned function as the distance metric

In [10]:
class FunctionDT(nn.Module):
    """Distance-transform layer whose "metric" is an arbitrary function ``func``.

    Holds ``num_centers`` learnable centers. Every pairwise difference
    ``x - center`` is passed to ``func`` (intended to map the trailing
    ``input_dim`` axis to a single value). The trailing axis of the result is
    squeezed, negated and scaled by ``exp(inv_temp)``, giving one output
    feature per center.

    Args:
        input_dim: size of each input vector and of each center.
        num_centers: number of learnable centers (= number of output features).
        func: callable applied to the differences, ``[input_dim -> 1]``.
        inv_temp: initial value of the learnable log-scale parameter.
    """

    def __init__(self, input_dim, num_centers, func, inv_temp=0.):
        super().__init__()
        self.input_dim = input_dim
        self.num_centers = num_centers
        self.func = func

        # Learnable log inverse-temperature; outputs are multiplied by exp(inv_temp).
        self.inv_temp = nn.Parameter(torch.ones(1) * inv_temp)
        # Centers drawn from a standard normal shrunk by 3 (std 1/3).
        self.centers = nn.Parameter(torch.randn(num_centers, input_dim) / 3.)

    def forward(self, x):
        # Assumes x is (batch, input_dim) -- TODO confirm for other callers.
        # (batch, 1, D) - (1, num_centers, D) broadcasts to all pairwise differences.
        diffs = x.unsqueeze(1) - self.centers.unsqueeze(0)
        scores = self.func(diffs).squeeze(-1)
        # Output is intentionally left un-normalised: no per-sample mean/std standardisation.
        return -scores * torch.exp(self.inv_temp)

In [11]:
from classes import DistanceRegressor, ConvexNN
from nflib.flows import SequentialFlow, ActNorm
import nflib.res_flow as irf

## Merge all models into a single benchmark

In [12]:
models_keys = ["l_0.5", "l_1", "l_2", "l_20", "stereo", "linear",]
def get_models(h = 5, key='linear', input_dim=784):
    """Build a small classifier whose first layer is selected by ``key``.

    Args:
        h: width of the first layer.
        key: one of ``models_keys``. "l_*" keys build ``dtnn.DistanceTransform``
            layers (presumably l_p distance layers with the given p), "stereo" a
            ``dtnn.StereographicTransform``, and "linear" a bias-free ``nn.Linear``.
        input_dim: size of the flattened input (784 = 28*28 for FashionMNIST).

    Returns:
        nn.Sequential: first layer -> BatchNorm1d(affine=False) -> LayerNorm -> ELU -> Linear(h, 10).

    Raises:
        KeyError: if ``key`` is not one of ``models_keys``; the message names the bad key.
    """
    # Lazy factories: only the requested layer is constructed.
    first_layer_factories = {
        "l_0.5": lambda: dtnn.DistanceTransform(input_dim, h, p=0.5, bias=False),
        "l_1": lambda: dtnn.DistanceTransform(input_dim, h, p=1, bias=False),
        # No explicit p is passed; presumably DistanceTransform defaults to p=2 -- confirm.
        "l_2": lambda: dtnn.DistanceTransform(input_dim, h, bias=False),
        "l_20": lambda: dtnn.DistanceTransform(input_dim, h, p=20, bias=False),
        "stereo": lambda: dtnn.StereographicTransform(input_dim, h, bias=False),
        "linear": lambda: nn.Linear(input_dim, h, bias=False),
    }
    if key not in first_layer_factories:
        raise KeyError(f"unknown model key {key!r}; expected one of {list(first_layer_factories)}")
    layer1 = first_layer_factories[key]()

    net = nn.Sequential(
        layer1,
        nn.BatchNorm1d(h, affine=False),
        nn.LayerNorm(h),
        nn.ELU(),
        nn.Linear(h, 10),
        )
    return net

In [13]:
models_func_keys = ["convex", "invex", "ordinary"]

def get_models_func(h = 500, func_h=500, key='ordinary', input_dim=784):
    """Build a classifier whose first layer is a ``FunctionDT`` with a learned metric.

    Args:
        h: number of centers in the ``FunctionDT`` layer.
        func_h: hidden width of the metric network.
        key: which metric network to use, one of ``models_func_keys``:
            "convex"   -> ``ConvexNN([input_dim, func_h, input_dim, 1])``;
            "invex"    -> ActNorm / ResidualFlow / ActNorm trunk + ``DistanceRegressor`` head;
            "ordinary" -> the same trunk + an ``nn.Linear(input_dim, 1)`` head, with
                          ``irf.remove_spectral_norm_model`` applied.
        input_dim: size of the flattened input (784 = 28*28 for FashionMNIST).

    Returns:
        nn.Sequential: FunctionDT -> BatchNorm1d(affine=False) -> LayerNorm -> ELU -> Linear(h, 10).

    Raises:
        KeyError: if ``key`` is not one of ``models_func_keys``; the message names the bad key.
    """
    # Validate up front so no sub-module is built for an unknown key.
    if key not in models_func_keys:
        raise KeyError(f"unknown function-model key {key!r}; expected one of {models_func_keys}")

    if key == "convex":
        layer1 = ConvexNN([input_dim, func_h, input_dim, 1])
    else:
        # Sub-modules are constructed trunk-first, head-last (argument evaluation order),
        # which keeps seeded weight initialisation identical for each key.
        layer1 = nn.Sequential(
                    ActNorm(input_dim),
                    irf.ResidualFlow(input_dim, [func_h], activation=irf.LeakyReLU),
                    ActNorm(input_dim),
                    DistanceRegressor(input_dim) if key == "invex" else nn.Linear(input_dim, 1),
                    )
        if key == "ordinary":
            # Presumably the unconstrained baseline: strip spectral norm from the trunk.
            irf.remove_spectral_norm_model(layer1)

    net = nn.Sequential(
        FunctionDT(input_dim, h, layer1),
        nn.BatchNorm1d(h, affine=False),
        nn.LayerNorm(h),
        nn.ELU(),
        nn.Linear(h, 10),
        )
    return net

In [14]:
## Following is copied from 
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch, model, optimizer):
    """Run one training epoch over the global ``train_loader``.

    Uses the module-level ``device`` and ``criterion``.

    Args:
        epoch: epoch index (not used in the computation).
        model: network mapping flattened (batch, features) images to class logits.
        optimizer: optimizer stepping ``model``'s parameters.

    Returns:
        tuple[float, float]: mean per-batch training loss and accuracy (percent).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in train_loader:
        # Flatten each image to a vector: the first layer expects (batch, features).
        inputs, targets = inputs.to(device).flatten(start_dim=1), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return train_loss / num_batches, 100. * correct / total

In [15]:
best_acc = -1
def test(epoch, model, model_name, save_dir=None):
    """Evaluate ``model`` on the global ``test_loader`` and track the best accuracy.

    Uses the module-level ``device`` and ``criterion``. Updates the module-level
    ``best_acc`` whenever this evaluation beats it.

    Args:
        epoch: epoch index, stored in the checkpoint when one is written.
        model: network mapping flattened (batch, features) images to class logits.
        model_name: checkpoint file stem, saved as ``<save_dir>/<model_name>.pth``.
        save_dir: directory for the best checkpoint so far. ``None`` (default)
            writes nothing.

    Returns:
        tuple[float, float]: mean per-batch test loss and accuracy (percent).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    global best_acc
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            # Flatten each image to a vector: the first layer expects (batch, features).
            inputs, targets = inputs.to(device).flatten(start_dim=1), targets.to(device)
            outputs = model(inputs)
            test_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches")

    acc = 100. * correct / total
    if acc > best_acc:
        best_acc = acc
        if save_dir is not None:
            import os
            os.makedirs(save_dir, exist_ok=True)
            state = {
                'model': model.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            torch.save(state, os.path.join(save_dir, f'{model_name}.pth'))
    return test_loss / num_batches, acc

In [16]:
# Adam step size, and number of epochs (also used as the cosine-annealing period T_max).
learning_rate = 1e-4
EPOCHS = 50

In [17]:
# First-layer widths swept in the benchmark.
H = [5, 10, 20, 100, 500]

# Show which model families will be evaluated at each width.
(models_keys, models_func_keys)

(['l_0.5', 'l_1', 'l_2', 'l_20', 'stereo', 'linear'],
 ['convex', 'invex', 'ordinary'])

In [18]:
# Resume from previously saved results when the file exists; otherwise start empty
# (the initialization cell then marks every configuration for a fresh run).
RESULTS_PATH = "./outputs/00.2_exp_acc_data_v2_LN_BN.json"
exp_acc_vals = {}
if os.path.exists(RESULTS_PATH):
    with open(RESULTS_PATH, 'r') as f:
        exp_acc_vals = json.load(f)

In [19]:
# Display the accuracy lists (keyed by width, then model) as loaded from disk.
exp_acc_vals

{'5': {'l_0.5': [70.65, 73.7, 73.06, 73.07, 72.21, 72.71, 73.52, 74.46],
  'l_1': [75.79, 76.54, 77.08, 76.93, 76.3, 77.42, 77.88, 77.89],
  'l_2': [76.84, 77.78, 78.56, 77.6, 77.28, 78.75, 78.92, 76.99],
  'l_20': [78.63, 79.43, 80.45, 79.09, 80.35, 79.59, 79.73, 77.87],
  'stereo': [81.33, 80.91, 79.68, 81.79, 80.97, 80.32, 81.24, 80.94],
  'linear': [81.54, 80.8, 80.96, 81.07, 81.5, 80.96, 80.86, 81.7],
  'convex': [88.48, 88.41, 88.31, 88.25, 88.1, 88.15, 88.27, 88.36],
  'invex': [88.96, 88.93, 88.3, 88.59, 88.59, 88.34, 88.34, 88.25],
  'ordinary': [83.45, 84.88, 84.77, 84.47, 84.72, 83.81, 84.66, 84.06]},
 '10': {'l_0.5': [77.68, 77.86, 78.36, 78.29, 77.42, 78.23, 78.79, 77.87],
  'l_1': [81.65, 81.45, 81.72, 81.74, 81.56, 81.77, 81.92, 81.84],
  'l_2': [82.71, 82.92, 82.6, 82.9, 82.72, 82.89, 83.22, 82.95],
  'l_20': [83.27, 83.17, 82.69, 83.62, 82.27, 82.91, 83.21, 82.07],
  'stereo': [84.76, 85.25, 85.09, 85.0, 85.11, 84.66, 84.87, 84.97],
  'linear': [84.54, 85.06, 85.07, 84

In [20]:
# ! mkdir outputs/00.2_exp_acc/

In [21]:
### Initialization: keep result lists that already have one entry per seed,
### and reset everything else to [] so the experiment cell reruns it.
EXPECTED_RUNS = 8  # one accuracy per seed; keep in sync with SEEDS in the experiment cell
for h in H:
    previous = exp_acc_vals.get(str(h), {})
    acc_dict = {}
    for key in models_keys + models_func_keys:
        print(f"Checking for {key} ; h:{h}")
        results = previous.get(str(key), [])
        if isinstance(results, list) and len(results) == EXPECTED_RUNS:
            print("Results found complete")
            acc_dict[str(key)] = results
        else:
            acc_dict[str(key)] = []
    exp_acc_vals[str(h)] = acc_dict
exp_acc_vals

Checking for l_0.5 ; h:5
Results found complete
Checking for l_1 ; h:5
Results found complete
Checking for l_2 ; h:5
Results found complete
Checking for l_20 ; h:5
Results found complete
Checking for stereo ; h:5
Results found complete
Checking for linear ; h:5
Results found complete
Checking for convex ; h:5
Results found complete
Checking for invex ; h:5
Results found complete
Checking for ordinary ; h:5
Results found complete
Checking for l_0.5 ; h:10
Results found complete
Checking for l_1 ; h:10
Results found complete
Checking for l_2 ; h:10
Results found complete
Checking for l_20 ; h:10
Results found complete
Checking for stereo ; h:10
Results found complete
Checking for linear ; h:10
Results found complete
Checking for convex ; h:10
Results found complete
Checking for invex ; h:10
Results found complete
Checking for ordinary ; h:10
Results found complete
Checking for l_0.5 ; h:20
Results found complete
Checking for l_1 ; h:20
Results found complete
Checking for l_2 ; h:20
Resul

{'5': {'l_0.5': [70.65, 73.7, 73.06, 73.07, 72.21, 72.71, 73.52, 74.46],
  'l_1': [75.79, 76.54, 77.08, 76.93, 76.3, 77.42, 77.88, 77.89],
  'l_2': [76.84, 77.78, 78.56, 77.6, 77.28, 78.75, 78.92, 76.99],
  'l_20': [78.63, 79.43, 80.45, 79.09, 80.35, 79.59, 79.73, 77.87],
  'stereo': [81.33, 80.91, 79.68, 81.79, 80.97, 80.32, 81.24, 80.94],
  'linear': [81.54, 80.8, 80.96, 81.07, 81.5, 80.96, 80.86, 81.7],
  'convex': [88.48, 88.41, 88.31, 88.25, 88.1, 88.15, 88.27, 88.36],
  'invex': [88.96, 88.93, 88.3, 88.59, 88.59, 88.34, 88.34, 88.25],
  'ordinary': [83.45, 84.88, 84.77, 84.47, 84.72, 83.81, 84.66, 84.06]},
 '10': {'l_0.5': [77.68, 77.86, 78.36, 78.29, 77.42, 78.23, 78.79, 77.87],
  'l_1': [81.65, 81.45, 81.72, 81.74, 81.56, 81.77, 81.92, 81.84],
  'l_2': [82.71, 82.92, 82.6, 82.9, 82.72, 82.89, 83.22, 82.95],
  'l_20': [83.27, 83.17, 82.69, 83.62, 82.27, 82.91, 83.21, 82.07],
  'stereo': [84.76, 85.25, 85.09, 85.0, 85.11, 84.66, 84.87, 84.97],
  'linear': [84.54, 85.06, 85.07, 84

In [22]:
# data_file = "./outputs/00.2_exp_acc_dict_v2.json"
SEEDS = [147, 258, 369, 741, 852, 963, 159, 357]
# Results are re-saved after every (h, key) sweep so a crash loses at most one sweep.
RESULTS_PATH = "./outputs/00.2_exp_acc_data_v2_LN_BN.json"

for h in H:
    acc_dict = exp_acc_vals[str(h)]

    for key in models_keys + models_func_keys:
        print("_________________________")
        print(f"Experimenting for {key} ; h:{h}")

        # Skip configurations that already hold one accuracy per seed.
        results = acc_dict.get(str(key))
        if results is not None:
            print(results)
            if len(results) == len(SEEDS):
                print("Results found complete")
                continue

        # Missing or partial results: rerun every seed from scratch.
        acc_dict[str(key)] = []
        build_model = get_models_func if key in models_func_keys else get_models
        for seed in tqdm(SEEDS):
            model_name = f"00.2_fmnist_{key}_h{h}_s{seed}"

            # Seed every RNG before building the loaders and the model, so weight
            # init and shuffling are reproducible per seed.
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
            test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

            net = build_model(h, key=key).to(device)

            optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
            best_acc = -1  # module-level global, raised by test() as accuracy improves
            for epoch in range(EPOCHS):
                train(epoch, net, optimizer)
                test(epoch, net, model_name)
                scheduler.step()
            acc_dict[str(key)].append(float(best_acc))

        # Write to a temp file and atomically swap it in, so an interrupted dump
        # never corrupts previously saved results.
        tmp_path = RESULTS_PATH + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(exp_acc_vals, f, indent=3)
        os.replace(tmp_path, RESULTS_PATH)

_________________________
Experimenting for l_0.5 ; h:5
[70.65, 73.7, 73.06, 73.07, 72.21, 72.71, 73.52, 74.46]
Results found complete
_________________________
Experimenting for l_1 ; h:5
[75.79, 76.54, 77.08, 76.93, 76.3, 77.42, 77.88, 77.89]
Results found complete
_________________________
Experimenting for l_2 ; h:5
[76.84, 77.78, 78.56, 77.6, 77.28, 78.75, 78.92, 76.99]
Results found complete
_________________________
Experimenting for l_20 ; h:5
[78.63, 79.43, 80.45, 79.09, 80.35, 79.59, 79.73, 77.87]
Results found complete
_________________________
Experimenting for stereo ; h:5
[81.33, 80.91, 79.68, 81.79, 80.97, 80.32, 81.24, 80.94]
Results found complete
_________________________
Experimenting for linear ; h:5
[81.54, 80.8, 80.96, 81.07, 81.5, 80.96, 80.86, 81.7]
Results found complete
_________________________
Experimenting for convex ; h:5
[88.48, 88.41, 88.31, 88.25, 88.1, 88.15, 88.27, 88.36]
Results found complete
_________________________
Experimenting for invex ; h:5
[

100%|████████████████████████████████████████████████████| 8/8 [1:38:54<00:00, 741.85s/it]


_________________________
Experimenting for invex ; h:100
[87.92, 88.3, 88.18, 88.25, 88.31, 88.45, 87.64, 88.48]
Results found complete
_________________________
Experimenting for ordinary ; h:100
[86.82, 87.19, 87.54, 87.08, 87.14, 87.05, 87.55, 86.32]
Results found complete
_________________________
Experimenting for l_0.5 ; h:500
[84.79, 84.72, 84.52, 84.7, 84.96, 84.59, 84.55, 85.13]
Results found complete
_________________________
Experimenting for l_1 ; h:500
[87.35, 86.89, 86.82, 86.94, 86.99, 87.09, 86.89, 87.21]
Results found complete
_________________________
Experimenting for l_2 ; h:500
[87.72, 87.58, 87.68, 87.74, 87.72, 87.51, 87.7, 87.71]
Results found complete
_________________________
Experimenting for l_20 ; h:500
[87.63, 87.74, 87.59, 87.44, 87.53, 87.61, 87.71, 87.79]
Results found complete
_________________________
Experimenting for stereo ; h:500
[89.29, 89.52, 89.28, 89.25, 89.29, 89.18, 89.3, 89.45]
Results found complete
_________________________
Experimenting

100%|███████████████████████████████████████████████████| 8/8 [7:10:27<00:00, 3228.47s/it]

_________________________
Experimenting for invex ; h:500
[88.67, 88.8, 89.3, 88.96, 88.93, 89.32, 89.48, 89.45]
Results found complete
_________________________
Experimenting for ordinary ; h:500
[87.62, 87.27, 87.31, 87.92, 87.65, 87.45, 87.49, 87.29]
Results found complete





In [23]:
# Display the accuracy lists after the experiment sweep has filled them in.
exp_acc_vals

{'5': {'l_0.5': [70.65, 73.7, 73.06, 73.07, 72.21, 72.71, 73.52, 74.46],
  'l_1': [75.79, 76.54, 77.08, 76.93, 76.3, 77.42, 77.88, 77.89],
  'l_2': [76.84, 77.78, 78.56, 77.6, 77.28, 78.75, 78.92, 76.99],
  'l_20': [78.63, 79.43, 80.45, 79.09, 80.35, 79.59, 79.73, 77.87],
  'stereo': [81.33, 80.91, 79.68, 81.79, 80.97, 80.32, 81.24, 80.94],
  'linear': [81.54, 80.8, 80.96, 81.07, 81.5, 80.96, 80.86, 81.7],
  'convex': [88.48, 88.41, 88.31, 88.25, 88.1, 88.15, 88.27, 88.36],
  'invex': [88.96, 88.93, 88.3, 88.59, 88.59, 88.34, 88.34, 88.25],
  'ordinary': [83.45, 84.88, 84.77, 84.47, 84.72, 83.81, 84.66, 84.06]},
 '10': {'l_0.5': [77.68, 77.86, 78.36, 78.29, 77.42, 78.23, 78.79, 77.87],
  'l_1': [81.65, 81.45, 81.72, 81.74, 81.56, 81.77, 81.92, 81.84],
  'l_2': [82.71, 82.92, 82.6, 82.9, 82.72, 82.89, 83.22, 82.95],
  'l_20': [83.27, 83.17, 82.69, 83.62, 82.27, 82.91, 83.21, 82.07],
  'stereo': [84.76, 85.25, 85.09, 85.0, 85.11, 84.66, 84.87, 84.97],
  'linear': [84.54, 85.06, 85.07, 84

In [24]:
# exp_acc_vals = \
# {'5': {'l_0.5': 67.05,
#   'l_1': 70.7,
#   'l_2': 78.15,
#   'l_inf': 79.52,
#   'stereo': 82.19,
#   'linear': 82.74,
#   'convex': 79.49,
#   'invex': 88.26,
#   'ordinary': 83.55},
#  '10': {'l_0.5': 72.08,
#   'l_1': 77.91,
#   'l_2': 82.35,
#   'l_inf': 83.98,
#   'stereo': 84.73,
#   'linear': 84.89,
#   'convex': 78.99,
#   'invex': 88.41,
#   'ordinary': 81.69}}

In [25]:
# data_file = "./outputs/00.2_exp_acc_dict.json"
# with open(data_file, "w") as f:
#     json.dump(exp_acc_vals, f, indent=3)

In [26]:
# # Opening JSON file
# with open(data_file, 'r') as f:
#     exp_acc_vals = json.load(f)

In [27]:
# exp_acc_vals

In [28]:
# prev_vals = {'5': {'l_0.5': []},
#  '10': {'l_0.5': [],
#   'l_1': [],
#   'l_2': [],
#   'l_20': [],
#   'stereo': [],
#   'linear': [],
#   'convex': [],
#   'invex': [88.01, 87.73, 86.62, 87.5, 87.27, 87.92, 88.37, 89.68],
#   'ordinary': []},
#  '20': {'l_0.5': [80.29, 80.1, 80.35, 79.6, 79.93, 79.95, 80.84, 80.0],
#   'l_1': [82.78, 83.2, 83.43, 82.75, 82.61, 82.86, 83.2, 82.79],
#   'l_2': [83.68, 83.97, 84.16, 83.26, 83.51, 83.68, 83.67, 83.87],
#   'l_20': [83.53, 84.09, 84.34, 84.14, 83.77, 84.22, 83.96, 83.7],
#   'stereo': [86.59, 86.2, 86.3, 86.21, 86.5, 86.09, 86.47, 86.16],
#   'linear': [86.41, 86.53, 86.25, 86.67, 86.45, 86.5, 86.49, 86.44],
#   'convex': [80.18, 69.13, 79.31, 73.66, 77.92, 68.66, 70.28, 81.17],
#   'invex': [89.42, 88.23, 88.39, 86.84, 87.68, 88.31, 87.39, 89.18],
#   'ordinary': [85.39, 87.24, 86.32, 86.36, 86.03, 85.37, 86.5, 85.87]},
#  '100': {'l_0.5': [83.16, 83.22, 83.11, 83.21, 83.18, 83.14, 83.06, 83.17],
#   'l_1': [85.21, 85.25, 85.35, 85.39, 85.44, 85.52, 85.43, 85.26],
#   'l_2': [86.0, 86.21, 86.12, 86.31, 86.14, 86.21, 86.22, 86.08],
#   'l_20': [85.65, 86.02, 85.82, 85.85, 85.96, 85.66, 85.59, 85.88],
#   'stereo': [88.27, 88.37, 88.22, 88.11, 88.39, 88.43, 87.9, 88.14],
#   'linear': [88.46, 88.38, 88.29, 88.51, 88.45, 88.55, 88.46, 88.28],
#   'convex': [72.73, 71.7, 79.13, 75.96, 72.98, 67.99, 77.88, 77.5],
#   'invex': [87.92, 88.3, 88.18, 88.25, 88.31, 88.45, 87.64, 88.48],
#   'ordinary': [86.82, 87.19, 87.54, 87.08, 87.14, 87.05, 87.55, 86.32]},
#  '500': {'l_0.5': [84.79, 84.72, 84.52, 84.7, 84.96, 84.59, 84.55, 85.13],
#   'l_1': [87.35, 86.89, 86.82, 86.94, 86.99, 87.09, 86.89, 87.21],
#   'l_2': [87.72, 87.58, 87.68, 87.74, 87.72, 87.51, 87.7, 87.71],
#   'l_20': [87.63, 87.74, 87.59, 87.44, 87.53, 87.61, 87.71, 87.79],
#   'stereo': [89.29, 89.52, 89.28, 89.25, 89.29, 89.18, 89.3, 89.45]}}

In [29]:
# prev_vals

In [30]:
# exp_acc_vals

In [31]:
# ### merge the two vals.. keep which has full set of results
# combined_dict = {}
# for h in H:
#     combined_dict[str(h)] = {}
#     for key, func_idx in zip(models_keys+models_func_keys, [0]*len(models_keys)+[1]*len(models_func_keys)):
        
#         try:
#             v1 = exp_acc_vals[str(h)][str(key)]
#         except Exception as e:
#             v1 = []
            
#         try:
#             v2 = prev_vals[str(h)][str(key)]
#         except Exception as e:
#             v2 = []
            
#         v = []
#         if len(v1) > len(v2):
#             v = v1
#         else:
#             v = v2
            
# #         print(h, key,"\n", v1,"\n", v2)

#         combined_dict[str(h)][str(key)] = v
# combined_dict

In [32]:
# with open(f"./outputs/00.2_exp_acc_data_v2_LN_BN.json", "w") as f:
#     json.dump(combined_dict, f, indent=3)

In [33]:
# import json
# with open("./outputs/00.2_exp_acc_data_v2_LN_BN.json", 'r') as f:
#     exp_acc_vals = json.load(f)

In [34]:
# Display the accuracy lists that the summary cells below aggregate.
exp_acc_vals

{'5': {'l_0.5': [70.65, 73.7, 73.06, 73.07, 72.21, 72.71, 73.52, 74.46],
  'l_1': [75.79, 76.54, 77.08, 76.93, 76.3, 77.42, 77.88, 77.89],
  'l_2': [76.84, 77.78, 78.56, 77.6, 77.28, 78.75, 78.92, 76.99],
  'l_20': [78.63, 79.43, 80.45, 79.09, 80.35, 79.59, 79.73, 77.87],
  'stereo': [81.33, 80.91, 79.68, 81.79, 80.97, 80.32, 81.24, 80.94],
  'linear': [81.54, 80.8, 80.96, 81.07, 81.5, 80.96, 80.86, 81.7],
  'convex': [88.48, 88.41, 88.31, 88.25, 88.1, 88.15, 88.27, 88.36],
  'invex': [88.96, 88.93, 88.3, 88.59, 88.59, 88.34, 88.34, 88.25],
  'ordinary': [83.45, 84.88, 84.77, 84.47, 84.72, 83.81, 84.66, 84.06]},
 '10': {'l_0.5': [77.68, 77.86, 78.36, 78.29, 77.42, 78.23, 78.79, 77.87],
  'l_1': [81.65, 81.45, 81.72, 81.74, 81.56, 81.77, 81.92, 81.84],
  'l_2': [82.71, 82.92, 82.6, 82.9, 82.72, 82.89, 83.22, 82.95],
  'l_20': [83.27, 83.17, 82.69, 83.62, 82.27, 82.91, 83.21, 82.07],
  'stereo': [84.76, 85.25, 85.09, 85.0, 85.11, 84.66, 84.87, 84.97],
  'linear': [84.54, 85.06, 85.07, 84

In [36]:
# Summarise each (width, model) run list as [mean, std, max], rounded to 2 decimals.
# np.std uses the population estimator (ddof=0).
final_stats = {}
for h, per_model in exp_acc_vals.items():
    final_stats[h] = {}
    for key, accs in per_model.items():
        # Named `accs`, not `data`: `data` would shadow `torch.utils.data`,
        # which the DataLoader cells rely on if re-run.
        if len(accs) == 0:
            # Incomplete sweeps are stored as []; np.max would raise on them.
            final_stats[h][key] = [np.nan, np.nan, np.nan]
            continue
        final_stats[h][key] = [
            np.round(np.mean(accs), 2),
            np.round(np.std(accs), 2),
            np.round(np.max(accs), 2),
        ]
final_stats

{'5': {'l_0.5': [72.92, 1.07, 74.46],
  'l_1': [76.98, 0.7, 77.89],
  'l_2': [77.84, 0.76, 78.92],
  'l_20': [79.39, 0.8, 80.45],
  'stereo': [80.9, 0.6, 81.79],
  'linear': [81.17, 0.33, 81.7],
  'convex': [88.29, 0.12, 88.48],
  'invex': [88.54, 0.26, 88.96],
  'ordinary': [84.35, 0.49, 84.88]},
 '10': {'l_0.5': [78.06, 0.41, 78.79],
  'l_1': [81.71, 0.14, 81.92],
  'l_2': [82.86, 0.18, 83.22],
  'l_20': [82.9, 0.49, 83.62],
  'stereo': [84.96, 0.18, 85.25],
  'linear': [84.92, 0.21, 85.28],
  'convex': [88.96, 0.2, 89.3],
  'invex': [89.29, 0.12, 89.51],
  'ordinary': [85.91, 0.26, 86.26]},
 '20': {'l_0.5': [80.13, 0.34, 80.84],
  'l_1': [82.95, 0.27, 83.43],
  'l_2': [83.72, 0.26, 84.16],
  'l_20': [83.97, 0.26, 84.34],
  'stereo': [86.32, 0.17, 86.59],
  'linear': [86.47, 0.11, 86.67],
  'convex': [88.9, 0.09, 88.99],
  'invex': [88.18, 0.81, 89.42],
  'ordinary': [86.14, 0.58, 87.24]},
 '100': {'l_0.5': [83.16, 0.05, 83.22],
  'l_1': [85.36, 0.1, 85.52],
  'l_2': [86.16, 0.09, 86

In [53]:
# 72.92$\pm$1.07 (\textcolor{blue}{74.46})
# Build a (model x width) grid of LaTeX cells "mean$\pm$std (\textcolor{blue}{max})".
# NOTE(review): `T` shadows the `torchvision.transforms as T` alias imported at the top;
# the name is kept because the row-printing cell reads `T`.
n_models = max((len(per_model) for per_model in exp_acc_vals.values()), default=0)
T = np.empty([n_models, len(exp_acc_vals)], dtype=object)
for i, h in enumerate(exp_acc_vals):
    for j, key in enumerate(exp_acc_vals[h]):
        # Named `accs`, not `data`, to avoid shadowing `torch.utils.data`.
        accs = exp_acc_vals[h][key]
        if len(accs) == 0:
            T[j, i] = "--"  # incomplete sweep; np.max would raise
            continue
        mean = np.mean(accs)
        std = np.std(accs)
        maxm = np.max(accs)
        # "\\pm" is an explicit backslash; a bare "\p" is an invalid escape sequence.
        T[j, i] = f"{np.round(mean, 2)}$\\pm${np.round(std, 2)} (\\textcolor{{blue}}{{{np.round(maxm, 2)}}})"
T

array([['72.92$\\pm$1.07 (\\textcolor{blue}{74.46})',
        '78.06$\\pm$0.41 (\\textcolor{blue}{78.79})',
        '80.13$\\pm$0.34 (\\textcolor{blue}{80.84})',
        '83.16$\\pm$0.05 (\\textcolor{blue}{83.22})',
        '84.74$\\pm$0.2 (\\textcolor{blue}{85.13})'],
       ['76.98$\\pm$0.7 (\\textcolor{blue}{77.89})',
        '81.71$\\pm$0.14 (\\textcolor{blue}{81.92})',
        '82.95$\\pm$0.27 (\\textcolor{blue}{83.43})',
        '85.36$\\pm$0.1 (\\textcolor{blue}{85.52})',
        '87.02$\\pm$0.17 (\\textcolor{blue}{87.35})'],
       ['77.84$\\pm$0.76 (\\textcolor{blue}{78.92})',
        '82.86$\\pm$0.18 (\\textcolor{blue}{83.22})',
        '83.72$\\pm$0.26 (\\textcolor{blue}{84.16})',
        '86.16$\\pm$0.09 (\\textcolor{blue}{86.31})',
        '87.67$\\pm$0.08 (\\textcolor{blue}{87.74})'],
       ['79.39$\\pm$0.8 (\\textcolor{blue}{80.45})',
        '82.9$\\pm$0.49 (\\textcolor{blue}{83.62})',
        '83.97$\\pm$0.26 (\\textcolor{blue}{84.34})',
        '85.8$\\pm$0.15 (\\tex

In [58]:
# Print one LaTeX tabular row per model: cells joined by " & " and terminated by "\\"
# (no trailing separator, which would add an empty column).
for row in T:
    print(" & ".join(str(cell) for cell in row) + r" \\")

72.92$\pm$1.07 (\textcolor{blue}{74.46}) & 78.06$\pm$0.41 (\textcolor{blue}{78.79}) & 80.13$\pm$0.34 (\textcolor{blue}{80.84}) & 83.16$\pm$0.05 (\textcolor{blue}{83.22}) & 84.74$\pm$0.2 (\textcolor{blue}{85.13}) & 
76.98$\pm$0.7 (\textcolor{blue}{77.89}) & 81.71$\pm$0.14 (\textcolor{blue}{81.92}) & 82.95$\pm$0.27 (\textcolor{blue}{83.43}) & 85.36$\pm$0.1 (\textcolor{blue}{85.52}) & 87.02$\pm$0.17 (\textcolor{blue}{87.35}) & 
77.84$\pm$0.76 (\textcolor{blue}{78.92}) & 82.86$\pm$0.18 (\textcolor{blue}{83.22}) & 83.72$\pm$0.26 (\textcolor{blue}{84.16}) & 86.16$\pm$0.09 (\textcolor{blue}{86.31}) & 87.67$\pm$0.08 (\textcolor{blue}{87.74}) & 
79.39$\pm$0.8 (\textcolor{blue}{80.45}) & 82.9$\pm$0.49 (\textcolor{blue}{83.62}) & 83.97$\pm$0.26 (\textcolor{blue}{84.34}) & 85.8$\pm$0.15 (\textcolor{blue}{86.02}) & 87.63$\pm$0.11 (\textcolor{blue}{87.79}) & 
80.9$\pm$0.6 (\textcolor{blue}{81.79}) & 84.96$\pm$0.18 (\textcolor{blue}{85.25}) & 86.32$\pm$0.17 (\textcolor{blue}{86.59}) & 88.23$\pm$0.17 

In [60]:
import torch

In [80]:
def inv_streo(x):
    """Inverse stereographic projection: lift each row of ``x`` onto the unit sphere.

    A row ``x`` with squared norm ``s`` maps to ``(2x / (s + 1), (s - 1) / (s + 1))``,
    so one extra trailing coordinate is appended along dim 1.
    """
    r2 = x.pow(2).sum(dim=1, keepdim=True)  # squared l2 norm per row
    denom = r2 + 1
    return torch.cat((x * 2 / denom, (r2 - 1) / denom), dim=1)

In [81]:
# Quick sanity check: lift a random 2-D point to 3 coordinates on the sphere.
# Unseeded, so the displayed values differ on every run.
inv_streo(torch.randn(1, 2))

tensor([[-0.5693,  0.6854,  0.4539]])

In [82]:
def streo(x):
    """Stereographic projection, undoing ``inv_streo``.

    Treats the last column of ``x`` as the extra sphere coordinate ``t`` and the
    remaining columns as ``y``. Returns ``y / 2 * ((1 + t) / (1 - t) + 1)``.
    Undefined (division by zero) at ``t == 1``, the projection pole.
    """
    head, t = x[:, :-1], x[:, -1:]
    r2 = (1 + t) / (1 - t)  # recovered squared norm of the original point
    return head / 2 * (r2 + 1)

In [84]:
# Round-trip check: streo must undo inv_streo, up to float32 rounding.
x = torch.randn(1,2)
print(x)
z = inv_streo(x)
print(z)
torch.testing.assert_close(streo(z), x, rtol=1e-4, atol=1e-5)
streo(z)

tensor([[-1.8530,  0.0267]])
tensor([[-0.8358,  0.0120,  0.5490]])


tensor([[-1.8530,  0.0267]])