In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import os, sys, pathlib, random, time, pickle, copy
from tqdm import tqdm

In [2]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
import torch.optim as optim
from torch.utils import data

In [4]:
import nflib
from nflib.flows import SequentialFlow, ActNorm, ActNorm2D, BatchNorm1DFlow, BatchNorm2DFlow
import nflib.coupling_flows as icf
import nflib.res_flow as irf

In [5]:
seed = 123

# Seed every RNG the notebook can draw from: Python's `random`, NumPy's global RNG,
# and torch (CPU generator plus every visible CUDA device).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

### Datasets

In [6]:
# Both splits are only converted to float tensors in [0, 1]: no augmentation and no normalisation.
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])

# FashionMNIST train/test splits (downloaded on first use).
# To run on MNIST instead, replace `datasets.FashionMNIST` with `datasets.MNIST`.
train_dataset, test_dataset = (
    datasets.FashionMNIST(root="./data/", train=is_train, download=True, transform=tfm)
    for is_train, tfm in ((True, train_transform), (False, test_transform))
)

In [7]:
# Batches of 50 with 2 worker processes; only the training split is shuffled.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=ds, batch_size=50, shuffle=is_train, num_workers=2)
    for ds, is_train in ((train_dataset, True), (test_dataset, False))
)

In [8]:
for xx, yy in train_loader:
    break

In [9]:
xx.shape

torch.Size([50, 1, 28, 28])

### Model

In [10]:
# Invertible backbone: residual flows interleaved with invertible 2x2 pooling. Judging by the
# shapes passed to each ConvResidualFlow, pooling trades spatial size for channels:
# [1, 28, 28] -> [4, 14, 14] -> [16, 7, 7], which is then flattened to 16*7*7 = 784 features.
actf = irf.Swish
# Norm1D = BatchNorm1DFlow
# Norm2D = BatchNorm2DFlow
# Normalisation layers used between flows; swap for the BatchNorm*Flow pair above to compare.
Norm1D = ActNorm
Norm2D = ActNorm2D
flows = [
    Norm2D(1), ## can remove (input normalisation; optional)
    # second argument = hidden channel widths of the residual block (e.g. Conv2d(1,16) -> Conv2d(16,1))
    irf.ConvResidualFlow([1, 28, 28], [16], kernels=3, activation=actf),
    irf.InvertiblePooling(2),
    Norm2D(4),
    irf.ConvResidualFlow([4, 14, 14], [64], kernels=3, activation=actf),
    irf.InvertiblePooling(2),
    Norm2D(16),
    irf.ConvResidualFlow([16, 7, 7], [64, 64], kernels=3, activation=actf),
    Norm2D(16),
    irf.Flatten(img_size=[16, 7, 7]),
    Norm1D(16*7*7),
        ]
        

# backbone = SequentialFlow(flows)
# Plain nn.Sequential for training; later cells wrap the same layers in SequentialFlow to use .inverse().
backbone = nn.Sequential(*flows)

In [11]:
backbone.to(device)

Sequential(
  (0): ActNorm2D()
  (1): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Swish()
      (2): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (2): InvertiblePooling()
  (3): ActNorm2D()
  (4): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Swish()
      (2): Conv2d(64, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (5): InvertiblePooling()
  (6): ActNorm2D()
  (7): ConvResidualFlow(
    (resblock): ModuleList(
      (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Swish()
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): Swish()
      (4): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (8): ActNorm2D()
  (9): Flatten()
  (10): ActNorm()
)

In [12]:
backbone(xx.to(device)).shape

torch.Size([50, 784])

In [13]:
def get_children(module):
    """Return the leaf modules under ``module`` in depth-first order.

    A module with no direct children is itself a leaf and is returned as a
    one-element list. Otherwise, the leaves of each direct child are
    concatenated in the order ``module.children()`` yields them.
    """
    direct_children = list(module.children())
    if not direct_children:
        return [module]
    return [leaf for sub in direct_children for leaf in get_children(sub)]

def remove_spectral_norm(model):
    """Best-effort removal of spectral-norm reparametrisations from ``model`` (in place).

    Every leaf module (as returned by ``get_children``) that has a ``weight``
    attribute is passed in turn to ``irf.remove_spectral_norm_conv``,
    ``irf.remove_spectral_norm`` and ``nn.utils.remove_spectral_norm``. A remover
    that does not apply to a layer raises; that is reported together with its
    reason and skipped rather than re-raised.

    Args:
        model: the ``nn.Module`` to modify.

    Returns:
        None.
    """
    # Lambdas keep the `irf` attribute lookups lazy and inside the try, as in a direct call.
    removers = (
        ("irf conv", lambda m: irf.remove_spectral_norm_conv(m)),
        ("irf lin", lambda m: irf.remove_spectral_norm(m)),
        ("nn", lambda m: nn.utils.remove_spectral_norm(m)),
    )
    for child in get_children(model):
        if not hasattr(child, 'weight'):
            continue
        print("Yes", child)
        for label, remover in removers:
            try:
                remover(child)
                print(f"Success : {label}")
            except Exception as e:
                # Keep going, but say why it failed instead of silently discarding the error.
                print(f"Failed : {label} ({type(e).__name__}: {e})")
    return

In [14]:
# remove_spectral_norm(backbone)

In [15]:
print("number of params: ", sum(p.numel() for p in backbone.parameters()))

number of params:  62067


In [16]:
for xx, yy in train_loader:
    tt = backbone(xx.to(device))
    print(xx.shape, tt.shape)
    break

torch.Size([50, 1, 28, 28]) torch.Size([50, 784])


In [17]:
class ConnectedClassifier_Linear(nn.Module):
    """Set-based classifier: a linear layer scores ``num_sets`` sets, and each set
    owns a learnable row of class logits (``cls_weight``).

    The forward pass turns the set scores into assignment weights (a softmax, or
    an exact one-hot when ``hard=True``) and returns the weighted mix of the sets'
    class-logit rows. The result is used as ordinary logits (e.g. for
    ``nn.CrossEntropyLoss``).

    Args:
        input_dim: size of the flattened input features.
        num_sets: number of sets.
        output_dim: number of classes.
        inv_temp: initial value of the learnable *log* inverse temperature; set
            scores are multiplied by ``exp(inv_temp)``.
    """

    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        # Stored in log space: forward() scales by exp(inv_temp), which keeps the scale positive.
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)

        self.linear = nn.Linear(input_dim, num_sets)

        # Random class logits per set, with set `ns` pushed (logit 5) towards class `ns % output_dim`,
        # so the sets start out spread round-robin over the classes.
        init_val = torch.randn(num_sets, output_dim)
        set_ids = torch.arange(num_sets)
        init_val[set_ids, set_ids % output_dim] = 5
        self.cls_weight = nn.Parameter(init_val)

        # Set-assignment weights from the most recent forward pass, shape (batch, num_sets).
        self.cls_confidence = None


    def forward(self, x, hard=False):
        """Return class logits of shape ``(batch, output_dim)`` for features ``x``.

        Args:
            x: ``(batch, input_dim)`` features.
            hard: if True, assign each sample to its single highest-scoring set
                (exact one-hot, no gradient through the assignment; meant for
                evaluation). Otherwise a softmax over the set scores is used.

        Side effect: stores the assignment matrix in ``self.cls_confidence``.
        """
        scores = self.linear(x)*torch.exp(self.inv_temp)
        if hard:
            # Exact one-hot on the argmax. A softmax(scores*1e5) approximation can overflow
            # to inf (giving NaN, especially in reduced precision) and splits mass on near-ties.
            conf = torch.zeros_like(scores).scatter_(1, scores.argmax(dim=1, keepdim=True), 1.0)
        else:
            conf = torch.softmax(scores, dim=1)
        self.cls_confidence = conf
        # Rows of `conf` sum to 1, so this is a convex combination of the sets' class-logit rows.
        # cls_weight itself is not normalised, so the output is logits, not probabilities.
        return conf@self.cls_weight

In [18]:
class ConnectedClassifier_Distance(nn.Module):
    """Set-based classifier driven by distances to ``num_sets`` learnable centres.

    Each set has a centre in feature space (initialised uniformly in [-1, 1)), a
    learnable scalar bias, and a learnable row of class logits (``cls_weight``).
    Samples are assigned to sets by a softmax over negative scaled distances (or
    an exact one-hot on the nearest set when ``hard=True``). The output is the
    weighted mix of the sets' class-logit rows, used as ordinary logits.

    Args:
        input_dim: size of the flattened input features.
        num_sets: number of sets / centres.
        output_dim: number of classes.
        inv_temp: initial value of the learnable *log* inverse temperature;
            distances are multiplied by ``exp(inv_temp)``.
    """

    def __init__(self,input_dim, num_sets, output_dim, inv_temp=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_sets = num_sets
        # Stored in log space: forward() scales by exp(inv_temp), which keeps the scale positive.
        self.inv_temp = nn.Parameter(torch.ones(1)*inv_temp)

        self.centers = nn.Parameter(torch.rand(num_sets, input_dim)*2-1)
        self.bias = nn.Parameter(torch.zeros(1, num_sets))

        # Random class logits per set, with set `ns` pushed (logit 5) towards class `ns % output_dim`,
        # so the sets start out spread round-robin over the classes.
        init_val = torch.randn(num_sets, output_dim)
        set_ids = torch.arange(num_sets)
        init_val[set_ids, set_ids % output_dim] = 5
        self.cls_weight = nn.Parameter(init_val)

        # Set-assignment weights from the most recent forward pass, shape (batch, num_sets).
        self.cls_confidence = None


    def forward(self, x, hard=False):
        """Return class logits of shape ``(batch, output_dim)`` for features ``x``.

        Args:
            x: ``(batch, input_dim)`` features.
            hard: if True, assign each sample to its single nearest set (exact
                one-hot on the smallest scaled, biased distance; no gradient through
                the assignment, meant for evaluation). Otherwise a softmax over the
                negative distances is used.

        Side effect: stores the assignment matrix in ``self.cls_confidence``.
        """
        dists = torch.cdist(x, self.centers)
        # Divide by sqrt(input_dim) so the diagonal of the unit hypercube has length 1 in any dimension.
        dists = dists/np.sqrt(self.input_dim) + self.bias
        dists = dists*torch.exp(self.inv_temp)
        if hard:
            # Exact one-hot on the nearest set. A softmax(-dists*1e5) approximation can overflow
            # (especially in reduced precision) and splits mass on near-ties.
            conf = torch.zeros_like(dists).scatter_(1, dists.argmin(dim=1, keepdim=True), 1.0)
        else:
            conf = torch.softmax(-dists, dim=1)
        self.cls_confidence = conf
        # Rows of `conf` sum to 1, so this is a convex combination of the sets' class-logit rows.
        # cls_weight itself is not normalised, so the output is logits, not probabilities.
        return conf@self.cls_weight

In [19]:
train_loader.dataset

Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: ./data/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
           )

In [20]:
train_loader.dataset.transforms

StandardTransform
Transform: Compose(
               ToTensor()
           )

In [21]:
#### C10
classifier = ConnectedClassifier_Distance(784, 100, 10, inv_temp=3.)
#### for MLP based classification
# classifier = nn.Sequential(nn.Linear(784, 100), nn.SELU(), nn.Linear(100, 10))

classifier = classifier.to(device)

In [22]:
print("number of params: ", sum(p.numel() for p in backbone.parameters()))
print("number of params: ", sum(p.numel() for p in classifier.parameters()))

number of params:  62067
number of params:  79501


In [23]:
model = nn.Sequential(backbone, classifier).to(device)

In [24]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  141568


## Training

In [25]:
model_name = 'fmnist_actnorm_multi_invex'
# model_name = 'fmnist_batchnorm_multi_invex'

In [26]:
EPOCHS = 50
# The classifier returns mixed class logits, so plain cross-entropy applies.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Cosine decay over EPOCHS scheduler steps; the (commented-out) training loop steps it once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [27]:
## Adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/main.py

def train(epoch):
    """Run one optimisation epoch over ``train_loader`` and print mean loss / accuracy.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer``,
    ``train_loader`` and ``device``. The LR scheduler is not stepped here.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    for n_batches, (inputs, targets) in enumerate(tqdm(train_loader), start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        _, predicted = outputs.max(1)
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()
    print(f"[Train] {epoch} Loss: {loss_sum/n_batches:.3f} | Acc: {100.*n_correct/n_seen:.3f} {n_correct}/{n_seen}")
    return

In [28]:
best_acc = -1
def test(epoch):
    """Evaluate ``model`` on ``test_loader``, print loss / accuracy, and checkpoint on improvement.

    Whenever the accuracy (in %) beats the module-level ``best_acc``, the model
    and optimizer state, the accuracy and the epoch are saved to
    ``./models/{model_name}.pth``, and ``best_acc`` is updated.
    """
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            # Optimizer state too, so a resumed run keeps Adam's moment estimates.
            'optimizer': optimizer.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('models', exist_ok=True)
        torch.save(state, f'./models/{model_name}.pth')
        best_acc = acc

In [29]:
start_epoch = 0  # first epoch to run: 0, or the epoch after the checkpointed one when resuming
resume = False

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('./models'), 'Error: no checkpoint directory found!'
    # map_location lets a checkpoint written on another device (e.g. CUDA) load on the current one.
    checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    # Older checkpoints only store the model; restore the optimizer when it was saved.
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    best_acc = checkpoint['acc']
    # The checkpoint is written after epoch `checkpoint['epoch']` finished, so continue with the next one.
    start_epoch = checkpoint['epoch'] + 1

In [30]:
# ### Train the whole damn thing

# for epoch in range(start_epoch, start_epoch+EPOCHS):
#     train(epoch)
#     test(epoch)
#     scheduler.step()

In [31]:
best_acc    #90.9 without first bn... with 90 ish is similar

-1

In [32]:
classifier.inv_temp

Parameter containing:
tensor([3.], device='cuda:0', requires_grad=True)

In [33]:
# Inspect the saved best checkpoint. map_location lets a checkpoint saved on CUDA load on the current device.
checkpoint = torch.load(f'./models/{model_name}.pth', map_location=device)
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

best_acc, start_epoch

(89.68, 17)

In [34]:
checkpoint.keys()

dict_keys(['model', 'acc', 'epoch'])

In [35]:
model_name

'fmnist_actnorm_multi_invex'

In [36]:
# Analysis of the saved best checkpoint follows (the model weights are loaded in the next cell).

In [37]:
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

### Hard test accuracy, with the number of test samples assigned to each set

In [38]:
model.eval()
print("Testing")

Testing


In [39]:
backbone, classifier = model[0], model[1]

In [40]:
classifier

ConnectedClassifier_Distance()

In [41]:
# Hard (one-set-per-sample) evaluation on the test split, tallying how many samples land in each set.
model.eval()
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
with torch.no_grad():
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        yout = classifier(backbone(xx), hard=True)
        assigned = torch.argmax(classifier.cls_confidence, dim=1)
        set_count += torch.bincount(assigned, minlength=classifier.num_sets)
        test_acc += (torch.argmax(yout, dim=1) == yy).sum().item()
        test_count += len(xx)

acc = float(test_acc)/test_count*100
print(f'Hard Acc:{acc:.2f}%')

100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 407.76it/s]

Hard Acc:89.31%





In [42]:
## Everything works: now collect the test images and labels assigned to each set (cluster)

In [43]:
# Hard evaluation again, also accumulating per-set sample counts (`set_count`) and correct counts (`set_acc`).
test_count = 0
test_acc = 0
set_count = torch.zeros(classifier.num_sets).to(device)
set_acc = torch.zeros(classifier.num_sets).to(device)
with torch.no_grad():
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        yout = classifier(backbone(xx), hard=True)

        # Set index (0..num_sets-1) each sample was assigned to.
        cls_indx = torch.argmax(classifier.cls_confidence, dim=1)
        correct = (torch.argmax(yout, dim=1) == yy).float()

        # Scatter-add per sample in one call each, instead of a Python loop of scalar GPU updates.
        set_count.index_add_(0, cls_indx, torch.ones_like(correct))
        set_acc.index_add_(0, cls_indx, correct)

        test_acc += correct.sum().item()
        test_count += len(xx)

print(f'Hard Acc:{float(test_acc)/test_count*100:.2f}%')
print(set_count.type(torch.long).tolist())

100%|██████████████████████████████████████████████████| 200/200 [00:00<00:00, 332.36it/s]

Hard Acc:89.31%
[0, 0, 0, 8, 0, 0, 0, 0, 0, 963, 0, 995, 0, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 17, 998, 0, 0, 0, 0, 0, 0, 0, 1078, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 974, 0, 0, 2, 2, 0, 1011, 0, 1016, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 294, 0, 0, 0, 0, 0, 0, 0, 1081, 0, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 471, 1054, 0, 0]





In [44]:
# Per-set summary: set index, class favoured by the set's weight row, #correct, #assigned, accuracy.
# #correct is read from `set_acc` directly (an exact integer-valued sum). Rebuilding it as
# ceil(acc*cnt) from the float32 ratio can come out one too high.
print(f"idx,\tcls,\tacc \ttot,\tpercent%")
set_cls = torch.argmax(classifier.cls_weight, dim=1).tolist()
for i, (cnt, n_correct) in enumerate(zip(set_count.type(torch.long).tolist(),
                                         set_acc.round().type(torch.long).tolist())):
    if cnt == 0: continue
    print(f"{i},\t{set_cls[i]},\t{n_correct}\t{cnt},\t{n_correct/cnt*100:.2f}%")

idx,	cls,	acc 	tot,	percent%
3,	3,	6	8,	75.00%
9,	9,	942	963,	97.72%
11,	1,	982	995,	98.59%
20,	0,	7	12,	58.33%
24,	4,	7	17,	35.29%
25,	5,	973	998,	97.49%
33,	3,	924	1078,	85.62%
42,	2,	0	2,	0.00%
52,	2,	830	974,	85.22%
55,	5,	1	2,	50.00%
56,	6,	1	2,	50.00%
58,	8,	961	1011,	95.05%
60,	0,	841	1016,	82.68%
76,	6,	189	294,	64.29%
84,	4,	869	1081,	80.39%
86,	6,	13	22,	59.09%
96,	6,	411	471,	87.05%
97,	7,	981	1054,	92.98%


In [46]:
for i, c in enumerate(train_dataset.classes):
    print(i, c)

0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot


## Collect the test samples assigned to each set

In [None]:
# One entry per non-empty set: its favoured class, exact correct/total counts, accuracy (%),
# and empty lists to be filled with the test images/labels routed to it.
# Counts come straight from `set_acc`; int(acc*cnt) on the float32 ratio can truncate one below.
idx_data_dict = {}
set_cls = torch.argmax(classifier.cls_weight, dim=1).tolist()
for i, (cnt, n_correct) in enumerate(zip(set_count.type(torch.long).tolist(),
                                         set_acc.round().type(torch.long).tolist())):
    if cnt == 0: continue
    idx_data_dict[i] = {"cls": set_cls[i], "correct": n_correct, "total": cnt,
                        "acc": n_correct/cnt*100, "xs": [], "ys": []}

idx_data_dict

In [None]:
# Route every test image/label to the set it is hard-assigned to. Each batch is copied to the CPU
# once, instead of issuing a separate device->host copy for every sample.
with torch.no_grad():
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        # Called for its side effect: it sets classifier.cls_confidence for this batch.
        yout = classifier(backbone(xx), hard=True)
        cls_indx = torch.argmax(classifier.cls_confidence, dim=1).cpu().tolist()

        for regi, x, y in zip(cls_indx, xx.cpu(), yy.cpu()):
            idx_data_dict[regi]["xs"].append(x)
            idx_data_dict[regi]["ys"].append(y)

In [None]:
# for key, val in idx_data_dict.items():
#     print(key)
#     xs = torch.cat(val['xs'], dim=0).to(device)
#     ys = torch.LongTensor(val['ys'])
    
#     del xs

In [None]:
classifier.centers.shape

In [None]:
invbackbone = SequentialFlow([*backbone])
invbackbone

In [None]:
classifier.input_dim

In [None]:
classifier.centers.data.shape

In [None]:
with torch.no_grad():
    xcenter = invbackbone.inverse(classifier.centers.data)

In [None]:
xcenter.shape

In [None]:
i=0

In [None]:
# Per-channel (mean, std) used to undo input normalisation before plotting. The data here are only
# ToTensor()-scaled, so the identity (mean 0, std 1) applies. Three entries are kept even for
# 1-channel (F)MNIST: an (H, W, 1) image broadcast against them becomes an (H, W, 3) grey image,
# presumably so that plt.imshow receives an RGB-shaped array.
dmean, dstd = np.array([0, 0, 0]) , np.array([1, 1, 1]) ##f/mnist
# dmean, dstd = np.array([0.4914, 0.4822, 0.4465]) , np.array([0.2023, 0.1994, 0.2010])
# dmean, dstd = np.array([0.5071, 0.4865, 0.4409]) , np.array([0.2009, 0.1984, 0.2023]) ##cifar100

In [None]:
img = xcenter[i].cpu().permute(1,2,0).numpy()*dstd + dmean
img = (img - img.min()) / (img.max()- img.min())
# img = np.clip(img, 0, 1)

plt.imshow(img)
print(i)
i+=1

#### Test the reversibility of the network

In [None]:
for xx, yy in test_loader:
    break

In [None]:
# i = 0

In [None]:
# ximg = xx[i].cpu().permute(1,2,0).numpy()*dstd + dmean
# i+=1
# plt.imshow(ximg)

In [None]:
xx = xx.to(device)

In [None]:
backbone(xx).data.shape

In [None]:
with torch.no_grad():
    xx_ = invbackbone.inverse(invbackbone(xx).data)

In [None]:
i = 0

In [None]:
ximg_ = xx_[i].cpu().permute(1,2,0).numpy()*dstd + dmean
plt.imshow(ximg_)
plt.show()
print("Reconstruction")
ximg = xx[i].cpu().permute(1,2,0).numpy()*dstd + dmean
plt.imshow(ximg)
plt.show()
print("True")
print(i)
i+=1

In [None]:
ximg_.max(), ximg_.min()

In [None]:
classifier.centers.max()

## Plot every set centre, indexed by set

In [None]:
with torch.no_grad():
    xcenter = invbackbone.inverse(classifier.centers.data)

In [None]:
os.makedirs(f"invex_out/multiinvex_centers_viz/{model_name}", exist_ok=True)

In [None]:
# imgsize = [3,32,32]
imgsize = [1,28,28]
    
if isinstance(train_dataset, datasets.FashionMNIST):
    train_dataset.classes[0] = 'T-shirt'

In [None]:
train_dataset.classes

In [None]:
def _to_display_image(img_tensor):
    """Convert one sample (flat or (C, H, W)) into a min-max-scaled HWC numpy image for plotting.

    The sample is reshaped to ``imgsize``, moved to HWC, de-normalised with the
    notebook-level ``dstd``/``dmean``, and rescaled to [0, 1].
    """
    img = img_tensor.detach().cpu().reshape(*imgsize).permute(1, 2, 0).numpy()*dstd + dmean
    lo, hi = img.min(), img.max()
    # A constant image would divide by zero (all NaN); render it as black instead.
    return (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)


def _save_display_image(img_tensor, prefix, key, cls_name):
    """Save one sample as ``<prefix>_(<key>)_<cls_name>.png`` in this model's visualisation folder."""
    # Class names may contain '/' (FashionMNIST's 'T-shirt/top'), which would be read as a directory.
    safe_cls_name = cls_name.replace('/', '-')
    plt.imshow(_to_display_image(img_tensor))
    plt.axis('off')
    plt.savefig(f"invex_out/multiinvex_centers_viz/{model_name}/{prefix}_({key})_{safe_cls_name}.png",
                bbox_inches='tight')
    plt.close()


flat_dim = int(np.prod(imgsize))
with torch.no_grad():
    for key, val in idx_data_dict.items():
        print(key, val['total'])
        cls_name = train_dataset.classes[val['cls']]
        xs = torch.stack(val['xs'], dim=0).to(device)   # images routed to this set
        zs = backbone(xs)                                # their latent codes, (N, flat_dim)

        # z-space medoid: the sample with the smallest total latent distance to the others.
        idx = torch.cdist(zs, zs).sum(dim=0).argmin()
        _save_display_image(xs[idx], "zmedoid", key, cls_name)

        # z-space nearest: the sample whose latent code is closest to the learned centre.
        idx = torch.norm(zs - classifier.centers[key].view(1, flat_dim), dim=1).argmin()
        _save_display_image(xs[idx], "znearest", key, cls_name)

        # The learned centre mapped back to input space through the inverse flow.
        _save_display_image(xcenter[key], "xcenter", key, cls_name)

        xs_flat = xs.view(-1, flat_dim)

        # x-space medoid: the sample with the smallest total pixel-space distance to the others.
        idx = torch.cdist(xs_flat, xs_flat).sum(dim=0).argmin()
        print("xMedoid", idx)
        _save_display_image(xs_flat[idx], "xmedoid", key, cls_name)

        # x-space nearest: the sample closest in pixel space to the inverted centre.
        idx = torch.norm(xs_flat - xcenter[key].view(1, flat_dim), dim=1).argmin()
        print("xNearest", idx)
        _save_display_image(xs_flat[idx], "xnearest", key, cls_name)

        del xs, xs_flat, zs

### Verify Reversibility

In [None]:
for xx, yy in train_loader:
    break

In [None]:
invback = SequentialFlow(backbone).to(device)
# invback.train()
invback.eval()

In [None]:
torch.cuda.empty_cache()

In [None]:
with torch.no_grad():
    yys_ = invback.forward_intermediate(xx.to(device))
    xxs_ = invback.inverse_intermediate(yys_[-1].to(device))

In [None]:
len(invback.flows), len(xxs_), len(yys_)

In [None]:
for i in range(len(invback.flows)):
    print(invback.flows[-i-1])
    print("f", yys_[-i-1].min(), yys_[-i-1].max(), yys_[-i-1].mean(), yys_[-i-1].std())
    print("r", xxs_[i+1].min(), xxs_[i+1].max())
    print()

In [None]:
invback.flows[2].training

In [None]:
xxs_[0].min(), xxs_[0].max()

In [None]:
yys_[-1].min(), yys_[-1].max()