In [3]:
import os
import sys
import time
import math
import argparse
import numpy as np
import pandas as pd
from scipy import stats
# from gensim.models import KeyedVectors
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

In [5]:
device = 'cuda' if torch.cuda.is_available else 'cpu'

In [6]:
model = models.resnet18()
model.fc = nn.Linear(512,10)#9*50)
model = model.to(device)

In [7]:
root = './data'
if not os.path.exists(root):
    os.mkdir(root)
    
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
# if not exist, download mnist dataset
train_set = MNIST(root=root, train=True, transform=trans, download=True)
test_set = MNIST(root=root, train=False, transform=trans, download=True)

batch_size = 32

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=200,
                shuffle=True)

print('==>>> total trainning batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!
==>>> total trainning batch number: 1875
==>>> total testing batch number: 50


In [6]:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

criterion = nn.CrossEntropyLoss()

for epoch in range(1):
    # trainning
    ave_loss = None
    for batch_idx, (x, target) in enumerate(train_loader):
        x = x.to(device)
        x = torch.cat((x,x,x), dim=1)
        target = target.to(device)        
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, target)
        ave_loss = loss.item() if ave_loss is None else ave_loss * 0.9 + loss.item() * 0.1
        loss.backward()
        optimizer.step()
        if (batch_idx) % 100 == 0 or (batch_idx+1) == len(train_loader):
            print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(
                epoch, batch_idx+1, ave_loss))

    # testing
    correct_cnt, ave_loss_val = 0, None
    total_cnt = 0
    for batch_idx_val, (x, target) in enumerate(test_loader):
        x, target = x.to(device), target.to(device)
        x = torch.cat((x,x,x), dim=1)
        out = model(x)
        loss = criterion(out, target)
        _, pred_label = torch.max(out.data, 1)
        total_cnt += x.size(0)
        correct_cnt += (pred_label == target.data).sum()
        # smooth average
        ave_loss_val = loss.item() if ave_loss_val is None else ave_loss_val * 0.9 + loss.item() * 0.1

        if(batch_idx_val+1) % 100 == 0 or (batch_idx+1) == len(test_loader):
            print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
                epoch, batch_idx_val+1, ave_loss_val, correct_cnt * 1.0 / total_cnt))

==>>> epoch: 0, batch index: 1, train loss: 2.393692
==>>> epoch: 0, batch index: 101, train loss: 0.370078
==>>> epoch: 0, batch index: 201, train loss: 0.293662
==>>> epoch: 0, batch index: 301, train loss: 0.140086
==>>> epoch: 0, batch index: 401, train loss: 0.146971
==>>> epoch: 0, batch index: 501, train loss: 0.144407
==>>> epoch: 0, batch index: 601, train loss: 0.164758
==>>> epoch: 0, batch index: 701, train loss: 0.112100
==>>> epoch: 0, batch index: 801, train loss: 0.139985
==>>> epoch: 0, batch index: 901, train loss: 0.069950
==>>> epoch: 0, batch index: 1001, train loss: 0.063384
==>>> epoch: 0, batch index: 1101, train loss: 0.044818
==>>> epoch: 0, batch index: 1201, train loss: 0.096861
==>>> epoch: 0, batch index: 1301, train loss: 0.094492
==>>> epoch: 0, batch index: 1401, train loss: 0.064533
==>>> epoch: 0, batch index: 1501, train loss: 0.064103
==>>> epoch: 0, batch index: 1601, train loss: 0.066849
==>>> epoch: 0, batch index: 1701, train loss: 0.054226
==>>

In [7]:
for batch_idx, (x, target) in enumerate(test_loader):
    x, target = x.to(device), target.to(device)
    x = torch.cat((x,x,x), dim=1)
    out = model(x)
    loss = criterion(out, target)
    _, pred_label = torch.max(out.data, 1)
    total_cnt += x.data.size()[0]
    correct_cnt += (pred_label == target.data).sum()
    # smooth average
    ave_loss = loss.item() if ave_loss is None else ave_loss * 0.9 + loss.item() * 0.1
    cor_pd = pd.DataFrame.corr(pd.DataFrame(out.cpu().detach().numpy()))
    gmm_clustering(cor_pd,f'clusters/cl_mnist{batch_idx}.txt')
    if (batch_idx+1) == len(test_loader):
        print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
            epoch, batch_idx+1, ave_loss, correct_cnt * 1.0 / total_cnt))


Final number of leaves: 10
Size of vocabulary: 10
Number of internal nodes: 9
Mean code word length: 4.5
Final number of leaves: 10
Size of vocabulary: 10
Number of internal nodes: 9
Mean code word length: 4.5
Final number of leaves: 10
Size of vocabulary: 10
Number of internal nodes: 9
Mean code word length: 4.6
Final number of leaves: 10
Size of vocabulary: 10
Number of internal nodes: 9
Mean code word length: 4.6
Final number of leaves: 10
Size of vocabulary: 10
Number of internal nodes: 9
Mean code word length: 4.5
Final number of leaves: 10
Size of vocabulary: 10
Number of internal nodes: 9
Mean code word length: 4.5
Final number of leaves: 10
Size of vocabulary: 10
Number of internal nodes: 9
Mean code word length: 4.4
Final number of leaves: 10
Size of vocabulary: 10
Number of internal nodes: 9
Mean code word length: 4.5
Final number of leaves: 10
Size of vocabulary: 10
Number of internal nodes: 9
Mean code word length: 4.5
Final number of leaves: 10
Size of vocabulary: 10
Numbe

In [8]:
def get_label_hierarchy(path):
    tree_labels_path = {}
    tree_label_full_path = {}
    tree_paths = set()
    p2t = {}
    with open(path,'r') as f:
        for line in f:
            label, l_path = line.split(',')[:2]
            full_path = l_path.strip()
            p2t[full_path] = int(label)
            tree_label_full_path[int(label)]=[int(l) for l in list(full_path)]
            path_labels ={}
            for k in range(1,1+len(full_path)):
                tree_paths.add(full_path[:k-1])
                path_labels[full_path[:k-1]] = int(full_path[k-1])
            tree_labels_path[int(label)] = path_labels
    path_inds = {k:i for i,k in enumerate(sorted(tree_paths,key=len))}
    tree_labels_path_indexed = {
        l:{path_inds[p]:p_l for p,p_l in path_dict.items()} 
        for l, path_dict in tree_labels_path.items()
    }
    labels_hier_idx = {}
    for k, v in tree_labels_path_indexed.items():
        idx,labs = list(zip(*v.items()))
        labels_hier_idx[k] = (list(idx),list(labs))
    return labels_hier_idx, len(tree_paths), path_inds, p2t

In [9]:
labels_hier_idx_l, num_of_paths_l, path_idx_l, p2t_l = [], [], [], []
for i in range(50):
    labels_hier_idx, num_of_paths, path_idx, p2t = get_label_hierarchy(f'clusters/cl_mnist{i}.txt')
    labels_hier_idx = {k:(torch.tensor(v[0]).long().to(device),torch.tensor(v[1]).float().to(device)) for k,v in labels_hier_idx.items()}
    labels_hier_idx_l.append(labels_hier_idx)
    num_of_paths_l.append(num_of_paths)
    path_idx_l.append(path_idx)
    p2t_l.append(p2t)

In [10]:
NUM_HSMX = len(labels_hier_idx_l)
NUM_PATHS = num_of_paths

In [11]:
comb_label_idx = {}
for k in range(10):
    comb_idx = torch.cat([h_idx[k][0] + NUM_PATHS*m for m, h_idx in enumerate(labels_hier_idx_l)])
    comb_labels = torch.cat([h_idx[k][1] for h_idx in labels_hier_idx_l])
    comb_label_idx[k] = (comb_idx, comb_labels)

In [12]:
class Identity(torch.nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

In [99]:
model.hsmx_weights = nn.Parameter(torch.randn(512, NUM_PATHS*NUM_HSMX+1,device=device)*0.025)
model.hsmx_bias = nn.Parameter(torch.randn(NUM_PATHS*NUM_HSMX+1,device=device)*0.025)
def restore_zero_col():
    model.hsmx_weights[:,-1] = 0
    model.hsmx_bias[-1] = 0
restore_zero_col()

In [None]:
model = models.resnet18()
model.fc = Identity()
model = model.to(device)

In [95]:
used_idx = torch.cat((comb_label_idx[0][0],comb_label_idx[3][0])).unique(sorted=True)
t_weights = model.hsmx_weights[:,used_idx]
t_bias = model.hsmx_bias[used_idx]

In [49]:
torch.gather(model.hsmx_weights, dim=1, torch.Tensor([0,2]))

Parameter containing:
tensor([[ 0.0487,  0.0009,  0.0250,  ...,  0.0268,  0.0272, -0.0083],
        [ 0.0119, -0.0750, -0.0222,  ..., -0.0240, -0.0206,  0.0403],
        [-0.0368,  0.0087,  0.0242,  ...,  0.0292,  0.0190,  0.0295],
        ...,
        [ 0.0205,  0.0034, -0.0059,  ..., -0.0117,  0.0062, -0.0273],
        [ 0.0109, -0.0225, -0.0311,  ..., -0.0192,  0.0205, -0.0060],
        [-0.0370,  0.0080,  0.0112,  ...,  0.0161,  0.0154, -0.0149]],
       device='cuda:0', requires_grad=True)

In [77]:
gather_ind = torch.Tensor([[1,2]]).long().expand(512,2)

In [101]:
tensor = torch.randint(10,[3,4,5], dtype=torch.long)
values = torch.arange(10, dtype=torch.long)
result = torch.nonzero(tensor[..., None] == values)
print(result.shape)

torch.Size([60, 4])


In [109]:
tensor

tensor([[[4, 8, 2, 1, 6],
         [3, 5, 2, 1, 4],
         [0, 9, 0, 9, 5],
         [1, 9, 6, 3, 2]],

        [[5, 6, 6, 2, 2],
         [8, 9, 2, 1, 2],
         [3, 3, 6, 2, 2],
         [3, 2, 2, 0, 5]],

        [[7, 5, 8, 9, 0],
         [4, 6, 8, 2, 8],
         [5, 5, 8, 6, 3],
         [8, 8, 8, 4, 3]]])

In [111]:
result[:,-1]

tensor([4, 8, 2, 1, 6, 3, 5, 2, 1, 4, 0, 9, 0, 9, 5, 1, 9, 6, 3, 2, 5, 6, 6, 2,
        2, 8, 9, 2, 1, 2, 3, 3, 6, 2, 2, 3, 2, 2, 0, 5, 7, 5, 8, 9, 0, 4, 6, 8,
        2, 8, 5, 5, 8, 6, 3, 8, 8, 8, 4, 3])

In [103]:
tensor

tensor([[[4, 8, 2, 1, 6],
         [3, 5, 2, 1, 4],
         [0, 9, 0, 9, 5],
         [1, 9, 6, 3, 2]],

        [[5, 6, 6, 2, 2],
         [8, 9, 2, 1, 2],
         [3, 3, 6, 2, 2],
         [3, 2, 2, 0, 5]],

        [[7, 5, 8, 9, 0],
         [4, 6, 8, 2, 8],
         [5, 5, 8, 6, 3],
         [8, 8, 8, 4, 3]]])

In [102]:
result

tensor([[0, 0, 0, 4],
        [0, 0, 1, 8],
        [0, 0, 2, 2],
        [0, 0, 3, 1],
        [0, 0, 4, 6],
        [0, 1, 0, 3],
        [0, 1, 1, 5],
        [0, 1, 2, 2],
        [0, 1, 3, 1],
        [0, 1, 4, 4],
        [0, 2, 0, 0],
        [0, 2, 1, 9],
        [0, 2, 2, 0],
        [0, 2, 3, 9],
        [0, 2, 4, 5],
        [0, 3, 0, 1],
        [0, 3, 1, 9],
        [0, 3, 2, 6],
        [0, 3, 3, 3],
        [0, 3, 4, 2],
        [1, 0, 0, 5],
        [1, 0, 1, 6],
        [1, 0, 2, 6],
        [1, 0, 3, 2],
        [1, 0, 4, 2],
        [1, 1, 0, 8],
        [1, 1, 1, 9],
        [1, 1, 2, 2],
        [1, 1, 3, 1],
        [1, 1, 4, 2],
        [1, 2, 0, 3],
        [1, 2, 1, 3],
        [1, 2, 2, 6],
        [1, 2, 3, 2],
        [1, 2, 4, 2],
        [1, 3, 0, 3],
        [1, 3, 1, 2],
        [1, 3, 2, 2],
        [1, 3, 3, 0],
        [1, 3, 4, 5],
        [2, 0, 0, 7],
        [2, 0, 1, 5],
        [2, 0, 2, 8],
        [2, 0, 3, 9],
        [2, 0, 4, 0],
        [2

In [78]:
torch.gather(model.hsmx_weights,dim=1,index=gather_ind.to(device))

tensor([[ 0.0009,  0.0250],
        [-0.0750, -0.0222],
        [ 0.0087,  0.0242],
        ...,
        [ 0.0034, -0.0059],
        [-0.0225, -0.0311],
        [ 0.0080,  0.0112]], device='cuda:0', grad_fn=<GatherBackward>)

In [68]:
torch.cat([torch.arange(512)[:,None], torch.Tensor([[1,2]]).long().expand(512,2)], dim=-1)

tensor([[  0,   1,   2],
        [  1,   1,   2],
        [  2,   1,   2],
        ...,
        [509,   1,   2],
        [510,   1,   2],
        [511,   1,   2]])

In [53]:
model.hsmx_weights[:,[0,1,2]]

tensor([[ 0.0487,  0.0009,  0.0250],
        [ 0.0119, -0.0750, -0.0222],
        [-0.0368,  0.0087,  0.0242],
        ...,
        [ 0.0205,  0.0034, -0.0059],
        [ 0.0109, -0.0225, -0.0311],
        [-0.0370,  0.0080,  0.0112]], device='cuda:0', grad_fn=<IndexBackward>)

In [None]:
def pred_class_hsfmx(pred, path_idx, p2t, start_ind=0):
    current_node=0
    current_path = []
    cur_node_path_idx = [0]
    while True:     
        next_path_pred = pred[start_ind+cur_node_path_idx[-1]]
        current_path.append('1' if next_path_pred.item() >= 0 else '0')
        new_path = ''.join(current_path)
        if new_path in p2t:
            return p2t[new_path]
        cur_node_path_idx.append(path_idx[new_path])

def pred_batch(output, path_idx_l, p2t_l):
    return torch.Tensor([
        pred_class_hsfmx(row, path_idx, p2t, k*NUM_PATHS) 
        for row in output
        for k, (path_idx, p2t) in enumerate(zip(path_idx_l,p2t_l))
    ]).long().to(device).reshape(output.size(0), NUM_HSMX)

In [32]:
list(model.parameters())[-1].size()

torch.Size([450])

In [None]:
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
criterion_hsmx = nn.BCEWithLogitsLoss()
for epoch in range(10):
    # trainning
    ave_loss = None
    for batch_idx, (x, target) in enumerate(train_loader):
        x = x.to(device)
        x = torch.cat((x,x,x), dim=1)
        target = target.to(device)        
        optimizer.zero_grad()
        out = model(x)
        y_hsmx_idx = torch.cat([row*NUM_PATHS*NUM_HSMX + comb_label_idx[l][0]  for row, l in enumerate(target.tolist())])
        y_hsmx_labels = torch.cat([comb_label_idx[l][1]  for row, l in enumerate(target.tolist())])
        out_hsmx =  torch.gather(out.flatten(), 0, y_hsmx_idx)
        loss  = criterion_hsmx(out_hsmx,y_hsmx_labels)

        ave_loss = loss.item() if ave_loss is None else ave_loss * 0.9 + loss.item() * 0.1
        loss.backward()
        optimizer.step()
        if (batch_idx) % 100 == 0 or (batch_idx+1) == len(train_loader):
            print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(
                epoch, batch_idx+1, ave_loss))
            
    # testing
    correct_cnt, ave_loss_val = 0, None
    total_cnt = 0
    for batch_idx, (x, target) in enumerate(test_loader):
        x, target = x.to(device), target.to(device)
        x = torch.cat((x,x,x), dim=1)
        out = model(x)
        y_hsmx_idx = torch.cat([row*num_of_paths + comb_label_idx[l][0]  for row, l in enumerate(target.tolist())])
        y_hsmx_labels = torch.cat([comb_label_idx[l][1]  for row, l in enumerate(target.tolist())])
        out_hsmx =  torch.gather(out.flatten(), 0, y_hsmx_idx)
        loss  = criterion_hsmx(out_hsmx,y_hsmx_labels)

        pred_label = pred_batch(out, path_idx_l, p2t_l)
        pred_label_mode = pred_label.mode(dim=1)[0]
        total_cnt += x.size(0)
        correct_cnt += (pred_label_mode == target).sum()
        # smooth average
        ave_loss_val = loss.item() if ave_loss_val is None else ave_loss_val * 0.9 + loss.item() * 0.1

        if(batch_idx+1) % 100 == 0 or (batch_idx+1) == len(test_loader):
            print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
                epoch, batch_idx+1, ave_loss_val, correct_cnt * 1.0 / total_cnt))

In [27]:
# testing
correct_cnt, ave_loss_val = 0, None
total_cnt = 0
for batch_idx, (x, target) in enumerate(test_loader):
    x, target = x.to(device), target.to(device)
    x = torch.cat((x,x,x), dim=1)
    out = model(x)
    y_hsmx_idx = torch.cat([row*num_of_paths + comb_label_idx[l][0]  for row, l in enumerate(target.tolist())])
    y_hsmx_labels = torch.cat([comb_label_idx[l][1]  for row, l in enumerate(target.tolist())])
    out_hsmx =  torch.gather(out.flatten(), 0, y_hsmx_idx)
    loss  = criterion_hsmx(out_hsmx,y_hsmx_labels)

    pred_label = pred_batch(out, path_idx_l, p2t_l)
    pred_label_mode = pred_label.mode(dim=1)[0]
    total_cnt += x.size(0)
    correct_cnt += (pred_label_mode == target).sum()
    # smooth average
    ave_loss_val = loss.item() if ave_loss_val is None else ave_loss_val * 0.9 + loss.item() * 0.1

    if(batch_idx+1) % 100 == 0 or (batch_idx+1) == len(test_loader):
        print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
            epoch, batch_idx+1, ave_loss_val, correct_cnt * 1.0 / total_cnt))

==>>> epoch: 0, batch index: 50, test loss: 0.759030, acc: 0.597


In [127]:
batch_idx

0

In [21]:
pred_label = pred_batch(out, path_idx_l, p2t_l)

In [22]:
for i, row in enumerate(pred_label):
    print(i, row.unique().size())

0 torch.Size([10])
1 torch.Size([10])
2 torch.Size([10])
3 torch.Size([9])
4 torch.Size([10])
5 torch.Size([8])
6 torch.Size([10])
7 torch.Size([9])
8 torch.Size([10])
9 torch.Size([10])
10 torch.Size([10])
11 torch.Size([10])
12 torch.Size([10])
13 torch.Size([10])
14 torch.Size([10])
15 torch.Size([10])
16 torch.Size([10])
17 torch.Size([10])
18 torch.Size([10])
19 torch.Size([10])
20 torch.Size([9])
21 torch.Size([9])
22 torch.Size([10])
23 torch.Size([10])
24 torch.Size([10])
25 torch.Size([10])
26 torch.Size([10])
27 torch.Size([10])
28 torch.Size([10])
29 torch.Size([10])
30 torch.Size([10])
31 torch.Size([10])
32 torch.Size([10])
33 torch.Size([10])
34 torch.Size([10])
35 torch.Size([9])
36 torch.Size([10])
37 torch.Size([10])
38 torch.Size([10])
39 torch.Size([10])
40 torch.Size([10])
41 torch.Size([10])
42 torch.Size([10])
43 torch.Size([10])
44 torch.Size([10])
45 torch.Size([10])
46 torch.Size([10])
47 torch.Size([9])
48 torch.Size([10])
49 torch.Size([10])
50 torch.Size([10