In [None]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms as T
from torch.utils import data

In [None]:
from tqdm import tqdm
import os, time, sys
import json

In [None]:
# Make the local checkout of the Input-Invex-Neural-Network repo importable
# (presumably where `dtnnlib` and `nflib`, imported below, live).
sys.path.extend(["./Input-Invex-Neural-Network/"])

In [None]:
import dtnnlib as dtnn

In [None]:
import nflib
from nflib.flows import SequentialFlow, ActNorm, ActNorm2D, BatchNorm1DFlow, BatchNorm2DFlow
import nflib.res_flow as irf

In [None]:
# ToTensor() scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
mnist_transform = T.Compose([T.ToTensor(), T.Normalize(mean=0.5, std=0.5)])

# Swap in `datasets.MNIST` here to run the same benchmark on handwritten digits.
dataset_cls = datasets.FashionMNIST
train_dataset, test_dataset = (
    dataset_cls(root="data/", train=is_train, download=True, transform=mnist_transform)
    for is_train in (True, False)
)

In [None]:
# Sanity check: shape of the raw image tensor stored on the training dataset
# (before mnist_transform; expected (num_images, 28, 28) — confirm).
train_dataset.data.shape

In [None]:
batch_size = 50
# Shared loader settings; only the training split is shuffled.
_loader_opts = {"num_workers": 4, "batch_size": batch_size}
train_loader = data.DataLoader(dataset=train_dataset, shuffle=True, **_loader_opts)
test_loader = data.DataLoader(dataset=test_dataset, shuffle=False, **_loader_opts)

In [None]:
# Use the first GPU when available; fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Grab a single training batch; the benchmark cell reuses `xx` to measure
# pairwise distances in embedding space.
xx, yy = next(iter(train_loader))
print(xx.shape)
xx, yy = xx.to(device), yy.to(device)
print(xx.shape, yy.shape)

## Training and evaluation functions

In [None]:
# Checkpoints written by test(..., save=True) are stored here as <model_name>.pth.
model_dir = "outputs/15.0_models"

In [None]:
# Shared loss for train() and test(): applied to the model's output logits and the
# class-index targets from the loaders.
criterion = nn.CrossEntropyLoss()

In [None]:
## Following is copied from
### https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

# Training
def train(epoch, model, optimizer):
    """Train ``model`` for one epoch over the global ``train_loader``.

    Args:
        epoch: epoch index; only used in the printed summary line.
        model: network returning class logits for a batch; updated in place.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.

    Relies on the notebook globals ``train_loader``, ``device`` and ``criterion``.
    Prints the mean per-batch loss and the accuracy over all samples seen, and
    returns None.
    """
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        ## Disabled experiment: append batch_size//10 uniform-noise images in [-1, 1)
        ## labelled as an extra class "10" (needs flattened 28*28 inputs and an
        ## 11-way classifier head).

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    # max(..., 1) keeps an empty loader from raising NameError / ZeroDivisionError.
    mean_loss = train_loss / max(num_batches, 1)
    acc = 100. * correct / max(total, 1)
    print(f"[Train] {epoch} Loss: {mean_loss:.3f} | Acc: {acc:.3f} {correct}/{total}")
    return

In [None]:
best_acc = -1
def test(epoch, model, model_name, save=False):
    """Evaluate ``model`` on the global ``test_loader`` and checkpoint on improvement.

    Args:
        epoch: epoch index; printed and stored in the checkpoint.
        model: network returning class logits for a batch.
        model_name: checkpoint file stem; the file is ``./{model_dir}/{model_name}.pth``.
        save: if True, write a checkpoint whenever accuracy beats ``best_acc``.

    Relies on the notebook globals ``test_loader``, ``device``, ``criterion`` and
    ``model_dir``. Updates the global ``best_acc`` whenever accuracy improves,
    even when ``save`` is False. Returns None.
    """
    global best_acc
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    # max(..., 1) keeps an empty loader from raising NameError / ZeroDivisionError.
    mean_loss = test_loss / max(num_batches, 1)
    acc = 100. * correct / max(total, 1)
    print(f"[Test] {epoch} Loss: {mean_loss:.3f} | Acc: {acc:.3f} {correct}/{total}")

    # Save checkpoint.
    if acc > best_acc:
        if save:
            # Only announce a save when one actually happens.
            print('Saving..')
            state = {
                'model': model.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            # makedirs (not mkdir): model_dir may be nested (e.g. "outputs/..."), and
            # mkdir fails when the parent is missing; exist_ok makes repeats harmless.
            os.makedirs(model_dir, exist_ok=True)
            torch.save(state, f'./{model_dir}/{model_name}.pth')
        best_acc = acc

## Models

In [None]:
class DistanceTransform_Epsilon(dtnn.DistanceTransformBase):
    """Distances to learnable centers, optionally with an extra "epsilon" column.

    ``forward`` returns ``1 - dists * exp(scaler)`` (plus ``bias`` when enabled),
    where ``dists`` comes from ``dtnn.DistanceTransformBase.forward`` with one
    column per center. When ``epsilon`` is not None, a column filled with
    ``self.epsilon.mu`` is appended before scaling, so the output has
    ``num_centers + 1`` columns.

    Args:
        input_dim: input feature dimension (forwarded to the base class).
        num_centers: number of centers (forwarded to the base class).
        p: forwarded to the base class as ``p`` (presumably the norm order of
            the distance — confirm in dtnnlib).
        bias: if True, add a learnable per-output bias initialised to zero.
        epsilon: initial value of the extra column, or None to disable it.
        itemp: initial inverse temperature; stored as ``log(itemp)`` in ``scaler``.
    """

    def __init__(self, input_dim, num_centers, p=2, bias=False, epsilon=0.1, itemp=1):
        # Forward the caller's `p`; it used to be hard-coded to 2 and silently ignored.
        super().__init__(input_dim, num_centers, p=p)

        nc = num_centers
        if epsilon is not None:
            nc += 1
        # Log-parameterised scale so the effective multiplier exp(scaler) stays positive.
        self.scaler = nn.Parameter(torch.log(torch.ones(1, 1)*itemp))
        self.bias = nn.Parameter(torch.ones(1, nc)*0) if bias else None

        if epsilon is None:
            self.epsilon = None
        else:
            # NOTE(review): dtnn.EMA is assumed to be a running average that is updated
            # by calling it and exposes its current value as `.mu` — confirm in dtnnlib.
            self.epsilon = dtnn.EMA(mu=epsilon)

    def forward(self, x):
        # assumes the batch dimension of x comes first (len(x) == batch size) — TODO confirm
        dists = super().forward(x)

        if self.epsilon is not None:
            if self.training:
                # Feed the batch's mean center distance (detached) to the running epsilon.
                # Alternatives previously tried: dists.min() / dists.max().
                mdist = dists.mean().data
                self.epsilon(mdist)
            eps_col = torch.ones(len(x), 1).to(x)*self.epsilon.mu
            dists = torch.cat([dists, eps_col], dim=1)

        ## scale the dists
        dists = 1-dists*torch.exp(self.scaler)

        if self.bias is not None:
            dists = dists+self.bias
        return dists

In [None]:
class LocalMLP_epsilonsoftmax(nn.Module):
    """Distance-to-centers layer -> softmax over centers (+ epsilon) -> linear head.

    Args:
        input_dim: input feature dimension.
        hidden_dim: number of centers in ``layer0``.
        output_dim: number of output logits.
        epsilon: initial epsilon for ``layer0``; None disables the extra epsilon
            column, in which case ``layer1`` has ``hidden_dim`` inputs instead of
            ``hidden_dim + 1``.
        itemp: initial inverse temperature passed to ``layer0``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, epsilon=1.0, itemp=1):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.new_hidden_dim = 0
        self.output_dim = output_dim
        # True when layer0 appends the extra epsilon column (the last softmax entry).
        self.use_epsilon = epsilon is not None

        self.layer0 = DistanceTransform_Epsilon(self.input_dim, self.hidden_dim, bias=False, epsilon=epsilon, itemp=itemp)

        hdim = self.hidden_dim
        if self.use_epsilon:
            hdim += 1

        self.softmax = nn.Softmax(dim=-1)
        self.layer1 = nn.Linear(hdim, self.output_dim, bias=False)

        # Latest softmax activations (detached via .data), kept for inspection.
        self.temp_maximum = None

    def forward(self, x):
        xo = self.layer0(x)
        ## dropout here creates 0 actv (is relatively high), hence serves as noise --> does not work for high values
        xo = self.softmax(xo)
        self.temp_maximum = xo.data

        if self.training and self.use_epsilon:
            # Zero the epsilon column's weights so the epsilon entry adds nothing to
            # the logits. Without an epsilon column the last column belongs to a real
            # center and must not be touched.
            # NOTE(review): the optimizer step after this forward can make the column
            # slightly non-zero again, so it is not exactly zero in eval mode.
            self.layer1.weight.data[:, -1] *= 0.
        xo = self.layer1(xo)
        return xo

## Benchmark Model Training

In [None]:
configs = {}
## global learning rate
learning_rate = 0.01
EPOCHS = 30
SEED = 2024
NUM_CLASSES = 10

for center_lr_scaler in [1.0, 0.01]:
    for data_init in [False, True]:
        for hidden_units in [100, 500]:
            init = "cent" if data_init else "rand"
            model_name = f"dtesm_identity_I{init}_clrs{center_lr_scaler}_h{hidden_units}_mean"
            ########################################
            print(model_name)
            torch.manual_seed(SEED)

            # Identity backbone: only flattens (1, 28, 28) images to 784-d vectors.
            flows = [
                irf.Flatten(img_size=[1, 28, 28]),
            ]
            backbone = nn.Sequential(*flows).to(device)

            print("num_parameters", sum([p.numel() for p in backbone.parameters()]))

            # Pairwise distances of the sample batch `xx` in embedding space with the
            # diagonal dropped: skipping the first element and viewing as (n-1, n+1)
            # puts every self-distance in the last column, which is then removed.
            yout = backbone(xx).data
            d = torch.cdist(yout, yout)
            n = d.shape[0]
            d = d.flatten()[1:].view(n-1, n+1)[:,:-1].reshape(n, n-1).cpu().numpy()

            print("embed dists -> min, max, mean, std", d.min(), d.max(), d.mean(), d.std())
            # Initial epsilon = mean off-diagonal distance of the sample batch.
            epsilon = d.mean()
            classifier = LocalMLP_epsilonsoftmax(784, hidden_units, NUM_CLASSES, epsilon=epsilon, itemp=1.0).to(device)

            ### initialization of data
            if data_init:
                # Use random training images as centers; each center votes for its label.
                idx = torch.randperm(len(train_loader.dataset))[:hidden_units]
                source, target = train_dataset.data[idx].reshape(-1, 1, 28, 28).to(device), train_dataset.targets[idx]
                # Same mapping as mnist_transform ((x/255 - 0.5) / 0.5): [0, 255] -> [-1, 1].
                # The previous /128-1 mapped 255 to 0.992 instead of 1.
                source = backbone(source.type(torch.float32)/127.5-1)
                classifier.layer0.centers.data = source

                # One-hot vote per center. A leftover `targets[i, -1] = 0.005` (from an
                # 11-class variant) used to overwrite the 1.0 of every class-9 center.
                targets = F.one_hot(target, num_classes=NUM_CLASSES).to(torch.float32)

                if classifier.layer0.epsilon is not None:
                    # The epsilon column votes for no class.
                    targets = torch.cat([targets, torch.zeros(1, NUM_CLASSES)], dim=0)

                classifier.layer1.weight.data = targets.t().contiguous().to(device)
            #################################################

            model = nn.Sequential(backbone, classifier)
            print("Testing at initialization..")
            test(-1, model, model_name="", save=False)

            # Centers get their own learning rate (learning_rate * center_lr_scaler);
            # all other parameters use the optimizer's default learning_rate.
            named_params = list(model.named_parameters())
            params = [
                {"params": [p for name, p in named_params if "centers" in name],
                 "lr": learning_rate*center_lr_scaler},
                {"params": [p for name, p in named_params if "centers" not in name]},
            ]
            ##################################################

            optimizer = torch.optim.Adam(params, lr=learning_rate)
            # NOTE(review): with SequentialLR, epoch 0 runs at 0.3*lr, epochs 1-2 at full lr
            # (warmup total_iters=1 but milestone=2), then cosine decay with
            # T_max=EPOCHS-1 — confirm this schedule is intended.
            warmup = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.3, total_iters=1)
            _scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS-1)
            scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, _scheduler], milestones=[2])

            # Reset the global best accuracy so each configuration checkpoints independently.
            best_acc = -1
            for epoch in range(EPOCHS):
                train(epoch, model, optimizer)
                test(epoch, model, model_name, save=True)
                scheduler.step()

In [None]:
# End the session once every configuration has been trained.
# NOTE(review): in IPython/Jupyter, exit() shuts down the kernel (presumably to release
# GPU memory). sys.exit would only raise SystemExit there, so keep exit() in the last cell.
exit(0)