In [13]:
import os
import random
import glob
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import torchvision.models as models

In [14]:
# Reproducibility: disable cuDNN autotuning, force deterministic kernels, and
# seed every RNG this notebook draws from (NumPy, Python, PyTorch).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [15]:
mnist_data_path = './MNIST_data'
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (1.0,)) then shifts them
# to [-0.5, 0.5].
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,)),
])

# download=True fetches the files only when they are missing under mnist_data_path.
train_set, test_set = (
    MNIST(root=mnist_data_path, train=split_is_train, transform=trans, download=True)
    for split_is_train in (True, False)
)

batch_size = 32

# Both loaders shuffle; the test loader uses a fixed batch size of 200.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=200, shuffle=True)

print('==>>> total trainning batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))

==>>> total trainning batch number: 1875
==>>> total testing batch number: 50


In [16]:
def get_label_hierarchy(path):
    """Parse a label-tree file into per-label lookups over internal tree nodes.

    Each non-blank line is ``label,bits[,...]``: an integer label followed by
    that label's root-to-leaf path written as a string of digits, one branch
    choice per level (extra columns are ignored).  Every proper prefix of a
    path -- including the empty root prefix ``''`` -- is an internal node.

    Args:
        path: path of the text file to read.

    Returns:
        A tuple ``(labels_hier_idx, num_paths, path_inds, p2t)``:
        - labels_hier_idx: label -> (list of node indices along its path,
          list of the branch digit taken at each of those nodes), root first.
        - num_paths: number of distinct internal nodes.
        - path_inds: internal-node prefix -> index.  Shorter prefixes come
          first and equal-length prefixes are ordered lexicographically, so
          the root ``''`` is always index 0 and the mapping is identical
          across interpreter runs.
        - p2t: full path string -> label.
    """
    tree_labels_path = {}
    tree_paths = set()
    p2t = {}
    with open(path, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            label, l_path = line.split(',')[:2]
            full_path = l_path.strip()
            label = int(label)
            p2t[full_path] = label
            # prefix of the path (internal node) -> branch digit taken there
            path_labels = {}
            for depth, bit in enumerate(full_path):
                prefix = full_path[:depth]
                tree_paths.add(prefix)
                path_labels[prefix] = int(bit)
            tree_labels_path[label] = path_labels
    # Sort by (length, prefix): sorting by length alone leaves equal-length
    # prefixes in set order, which for str depends on per-process hash
    # randomization and would reshuffle node indices between runs.
    ordered_paths = sorted(tree_paths, key=lambda p: (len(p), p))
    path_inds = {p: i for i, p in enumerate(ordered_paths)}
    labels_hier_idx = {
        label: ([path_inds[p] for p in path_labels], list(path_labels.values()))
        for label, path_labels in tree_labels_path.items()
    }
    return labels_hier_idx, len(tree_paths), path_inds, p2t


In [31]:
class HierarchicalSoftmaxEnsemble(nn.Module):
    """Ensemble of binary-tree hierarchical-softmax heads sharing one linear layer.

    Every file matched by ``trees_path`` describes one tree (parsed with
    ``get_label_hierarchy``).  All trees must have the same number of internal
    nodes (``num_paths``); the single ``nn.Linear`` emits ``num_paths`` logits
    per tree, laid out tree after tree, i.e. ``num_paths * num_hsfmx`` outputs.

    Args:
        input_dim: size of the feature vector passed to ``forward``.
        trees_path: glob pattern selecting the tree files.
        device: device for the layer and the label lookup tensors; defaults to
            ``'cuda'`` when available, otherwise ``'cpu'``.

    Raises:
        FileNotFoundError: if ``trees_path`` matches no file.
    """

    def __init__(self, input_dim, trees_path, device=None):
        super().__init__()
        self.num_paths = None
        self.path_indices = []
        self.path2label = []
        self.num_hsfmx = 0
        self.labels2path_labels = []
        self.labels2path_labels_combined = {}
        self.trees_path = trees_path
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.read_trees(trees_path)
        self.linear = nn.Linear(input_dim, self.num_paths * self.num_hsfmx)
        # Use the resolved device: the ``device`` argument itself may be None.
        self.linear.to(self.device)

    def read_trees(self, trees_path):
        """Load every tree file matching ``trees_path`` and build lookup tables.

        Fills ``labels2path_labels`` (per tree: label -> (node indices, branch
        bits) tensors on ``self.device``), ``path_indices``, ``path2label``,
        ``num_paths``, ``num_hsfmx`` and ``labels2path_labels_combined``.
        """
        # Sorted so tree k always maps to the same slice of the linear layer;
        # raw glob order depends on the filesystem.
        tree_files = sorted(glob.glob(trees_path))
        if not tree_files:
            raise FileNotFoundError('no tree files match {!r}'.format(trees_path))
        for path in tree_files:
            labels_hier_idx, num_of_paths, path_idx, p2t = get_label_hierarchy(path)
            self.labels2path_labels.append({
                label: (torch.tensor(idx).long().to(self.device),
                        torch.tensor(bits).float().to(self.device))
                for label, (idx, bits) in labels_hier_idx.items()
            })
            if self.num_paths is None:
                self.num_paths = num_of_paths
            else:
                assert self.num_paths == num_of_paths, (
                    'tree {!r} has {} internal nodes, expected {}'.format(
                        path, num_of_paths, self.num_paths))
            self.path_indices.append(path_idx)
            self.path2label.append(p2t)
        self.num_hsfmx = len(self.labels2path_labels)
        self._build_combined()

    def _build_combined(self):
        """Concatenate each label's per-tree (indices, bits) into a single pair.

        Node indices of tree ``m`` are shifted by ``m * num_paths`` so they
        address that tree's slice of the linear layer's output.
        """
        self.labels2path_labels_combined = {}
        # Iterate the actual label keys rather than assuming labels are 0..n-1.
        for label in self.labels2path_labels[0]:
            comb_idx = torch.cat([
                tree[label][0] + self.num_paths * m
                for m, tree in enumerate(self.labels2path_labels)])
            comb_labels = torch.cat([tree[label][1] for tree in self.labels2path_labels])
            self.labels2path_labels_combined[label] = (comb_idx, comb_labels)

    def to(self, *args, **kwargs):
        """Move the module and its label lookup tensors; returns ``self``.

        Accepts the same arguments as ``nn.Module.to``.  The lookup tensors are
        plain attributes (not registered buffers), so they are moved explicitly
        to whatever device the linear layer ends up on.
        NOTE: ``nn.Module.cuda()``/``cpu()`` do not go through ``to()``; prefer
        ``to()`` for this module.
        """
        super().to(*args, **kwargs)
        self.device = self.linear.weight.device
        self.labels2path_labels = [
            {label: (idx.to(self.device), bits.to(self.device))
             for label, (idx, bits) in tree.items()}
            for tree in self.labels2path_labels
        ]
        self._build_combined()
        return self

    def pred_label_single_hsfmx(self, pred, path_idx, p2t, start_ind=0):
        """Greedily decode one tree's label from one sample's logits.

        Starting at the root (node index 0), append '1' when the current
        node's logit is ``>= 0`` and '0' otherwise, until the accumulated bit
        string is a leaf listed in ``p2t``.

        Args:
            pred: 1-D logits for one sample (all trees concatenated).
            path_idx: internal-node prefix -> node index for this tree.
            p2t: leaf bit string -> label for this tree.
            start_ind: offset of this tree's logits inside ``pred``.

        Returns:
            The decoded integer label.
        """
        bits = ''
        node = 0  # the root prefix '' is assigned index 0 by get_label_hierarchy
        while True:
            bits += '1' if pred[start_ind + node].item() >= 0 else '0'
            if bits in p2t:
                return p2t[bits]
            node = path_idx[bits]

    def get_labels(self, output, get_mode=True):
        """Decode labels from raw logits with every tree.

        Args:
            output: logits of shape (batch, num_paths * num_hsfmx).
            get_mode: if True, return the per-sample majority vote over trees.

        Returns:
            LongTensor of shape (batch,) when ``get_mode`` else
            (batch, num_hsfmx), on ``self.device``.
        """
        # Decoding reads scalars one at a time; do it on a single host copy
        # instead of syncing with the device on every ``.item()``.
        host_output = output.detach().cpu()
        pred = torch.tensor([
            [self.pred_label_single_hsfmx(row, path_idx, p2t, k * self.num_paths)
             for k, (path_idx, p2t) in enumerate(zip(self.path_indices, self.path2label))]
            for row in host_output
        ], dtype=torch.long, device=self.device).reshape(output.size(0), self.num_hsfmx)
        return pred.mode(dim=1)[0] if get_mode else pred

    def forward(self, x, target=None, collect_paths=True, pred_labels=False, pred_labels__get_mode=True):
        """Compute logits and, optionally, the node logits/targets on each label's path.

        Args:
            x: features of shape (batch, input_dim).
            target: integer labels of shape (batch,); required when
                ``collect_paths`` is True.
            collect_paths: if False, return the raw logits only.
            pred_labels: if True, also return decoded labels.
            pred_labels__get_mode: forwarded to ``get_labels`` as ``get_mode``.

        Returns:
            ``output`` when ``collect_paths`` is False; otherwise
            ``(output_hsfmx, target_hsfmx)`` -- the flattened logits of every
            node on each sample's path in every tree and the matching 0/1
            targets -- plus decoded labels when ``pred_labels`` is True.

        Raises:
            ValueError: if ``collect_paths`` is True and ``target`` is None.
        """
        output = self.linear(x)
        if not collect_paths:
            return output
        if target is None:
            raise ValueError('target is required when collect_paths=True')

        per_sample = [self.labels2path_labels_combined[l] for l in target.tolist()]
        stride = self.num_paths * self.num_hsfmx  # logits per sample in flattened output
        y_hsfmx_idx = torch.cat([row * stride + idx for row, (idx, _) in enumerate(per_sample)])
        target_hsfmx = torch.cat([bits for _, bits in per_sample])
        output_hsfmx = torch.gather(output.flatten(), 0, y_hsfmx_idx)

        if pred_labels:
            return output_hsfmx, target_hsfmx, self.get_labels(output, pred_labels__get_mode)
        return output_hsfmx, target_hsfmx


In [32]:
class Identity(torch.nn.Module):
    """Pass-through module: ``forward`` hands back its input object unchanged.

    No ``__init__`` override is needed; ``torch.nn.Module.__init__`` is
    inherited as-is.
    """

    def forward(self, x):
        return x

In [33]:
# Randomly initialised ResNet-18 used as a feature extractor: its `fc`
# classification head is replaced with a pass-through, so forward() returns
# the pooled features instead of class logits.
model = models.resnet18()
model.fc = Identity()
model = model.to(device=device)

In [34]:
# Hierarchical-softmax head over every tree file in clusters/, optimised jointly
# with the ResNet trunk.
mnist_hsfmx = HierarchicalSoftmaxEnsemble(input_dim=512, trees_path='clusters/*', device=device)
optimizer = optim.Adam(list(model.parameters()) + list(mnist_hsfmx.parameters()), lr=0.001)
criterion_hsmx = nn.BCEWithLogitsLoss()
for epoch in range(3):
    # training (re-enable train mode: the evaluation below switches to eval)
    model.train()
    mnist_hsfmx.train()
    ave_loss = None
    for batch_idx, (x, target) in enumerate(train_loader):
        x = x.to(device)
        # replicate the single grayscale channel: ResNet-18's first conv takes 3 channels
        x = torch.cat((x, x, x), dim=1)
        target = target.to(device)
        optimizer.zero_grad()
        output_fc = model(x)
        # logits/targets of every node on each sample's path, for every tree
        output_hsfmx, target_hsfmx = mnist_hsfmx(output_fc, target)
        loss = criterion_hsmx(output_hsfmx, target_hsfmx)

        # exponential moving average of the loss, for logging only
        ave_loss = loss.item() if ave_loss is None else ave_loss * 0.9 + loss.item() * 0.1
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0 or (batch_idx + 1) == len(train_loader):
            print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(
                epoch, batch_idx + 1, ave_loss))

    # testing: eval() makes BatchNorm use (and stop updating) its running
    # statistics; no_grad() skips building the autograd graph.
    model.eval()
    mnist_hsfmx.eval()
    correct_cnt, ave_loss_val = 0, None
    total_cnt = 0
    with torch.no_grad():
        for batch_idx, (x, target) in enumerate(test_loader):
            x, target = x.to(device), target.to(device)
            x = torch.cat((x, x, x), dim=1)
            output_fc = model(x)
            output_hsfmx, target_hsfmx, preds = mnist_hsfmx(output_fc, target, pred_labels=True)
            loss = criterion_hsmx(output_hsfmx, target_hsfmx)

            total_cnt += x.size(0)
            correct_cnt += (preds == target).sum().item()
            # smooth average
            ave_loss_val = loss.item() if ave_loss_val is None else ave_loss_val * 0.9 + loss.item() * 0.1

            if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == len(test_loader):
                print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
                    epoch, batch_idx + 1, ave_loss_val, correct_cnt * 1.0 / total_cnt))


==>>> epoch: 0, batch index: 1, train loss: 0.771188
==>>> epoch: 0, batch index: 101, train loss: 0.073097
==>>> epoch: 0, batch index: 201, train loss: 0.064505
==>>> epoch: 0, batch index: 301, train loss: 0.041280
==>>> epoch: 0, batch index: 401, train loss: 0.050774
==>>> epoch: 0, batch index: 501, train loss: 0.038083
==>>> epoch: 0, batch index: 601, train loss: 0.032977
==>>> epoch: 0, batch index: 701, train loss: 0.032610
==>>> epoch: 0, batch index: 801, train loss: 0.033053
==>>> epoch: 0, batch index: 901, train loss: 0.027311
==>>> epoch: 0, batch index: 1001, train loss: 0.033389
==>>> epoch: 0, batch index: 1101, train loss: 0.025873
==>>> epoch: 0, batch index: 1201, train loss: 0.023599
==>>> epoch: 0, batch index: 1301, train loss: 0.021803
==>>> epoch: 0, batch index: 1401, train loss: 0.021761
==>>> epoch: 0, batch index: 1501, train loss: 0.027197
==>>> epoch: 0, batch index: 1601, train loss: 0.015227
==>>> epoch: 0, batch index: 1701, train loss: 0.020455
==>>