# Getting Started with simple-hierarchy-pytorch

In [1]:
!pip install simple-hierarchy-pytorch

Collecting simple-hierarchy-pytorch
  Using cached simple_hierarchy_pytorch-0.0.1-py3-none-any.whl (8.9 kB)
Installing collected packages: simple-hierarchy-pytorch
Successfully installed simple-hierarchy-pytorch-0.0.1


In [4]:
from simple_hierarchy.hierarchal_model import HierarchalModel
import torch
import torch.nn as nn

In [5]:
# Parent (name, n_classes) -> list of child (name, n_classes) pairs; "H" never
# appears as a child, so it is the root.
hierarchy = {
    ("A", 2): [("B", 5), ("C", 7)],
    ("H", 2): [("A", 2), ("K", 7), ("L", 10)],
}
# Four Linear(10, 10) layers with k=2: the first two form the shared base model and
# the last two are used inside every per-class head.
# NOTE(review): confirm whether the package copies or shares those two layers across heads.
model = HierarchalModel(
    model=nn.ModuleList([nn.Linear(10, 10) for _ in range(4)]),
    k=2,
    hierarchy=hierarchy,
    size=(10, 10, 10),
)
input = torch.rand((10, 10))  # batch of 10 samples with 10 features (shadows built-in input())
out = model(input)

In [6]:
# The model's hierarchy as a tree of (name, n_classes) nodes; as the last
# expression of the cell its repr is displayed.
model.tree

H 2 [A 2 [B 5 [], C 7 []], K 7 [], L 10 []]

In [7]:
# Shared base layers. With 4 layers and k=2 above, the output shows the first two
# Linear(10, 10) layers.
model.base_model

Sequential(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
)

In [8]:
# Per-class heads: a ModuleDict keyed by str((name, n_classes)), one Sequential per tree node.
# Each head adds two layers around the provided ones: an input Linear whose in_features grow
# by the parent's n_classes (e.g. 12 = 10 + 2 for children of a 2-class parent), and an
# output Linear sized to that node's n_classes.
# In a later version these may become customizable (should you want to provide distinct
# non-linear layers; currently this requires conversion into linear layers).
model.last_layers

ModuleDict(
  (('H', 2)): Sequential(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=2, bias=True)
  )
  (('A', 2)): Sequential(
    (0): Linear(in_features=12, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=2, bias=True)
  )
  (('B', 5)): Sequential(
    (0): Linear(in_features=12, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=5, bias=True)
  )
  (('C', 7)): Sequential(
    (0): Linear(in_features=12, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out

In [9]:
# Pair every returned tensor with its (name, n_classes) key; the order of the
# returned tuple is given by model.output_order.
for tensor, key in zip(out, model.output_order):
    print(key, " : ", tensor.shape)

('H', 2)  :  torch.Size([10, 2])
('A', 2)  :  torch.Size([10, 2])
('B', 5)  :  torch.Size([10, 5])
('C', 7)  :  torch.Size([10, 7])
('K', 7)  :  torch.Size([10, 7])
('L', 10)  :  torch.Size([10, 10])


# Class Hierarchy Models

In [None]:
import copy
import time
from itertools import chain
from typing import List, Dict, Tuple, Optional

import torch
import torch
import torch.nn as nn
import torch.optim as optim 
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

In [None]:
class Node(object):
  """One vertex of a class hierarchy.

  Attributes:
    name: label of the class group this node predicts.
    n_classes: number of classes predicted at this node.
    children: child ``Node`` objects in insertion order.
    parent: the parent ``Node``, or ``None`` for the root.
  """

  def __init__(self, name, n_classes, parent):
    self.n_classes = n_classes
    self.name = name
    self.children = []
    self.parent = parent

  def add_child(self, child):
    """Append ``child`` to ``children`` (the child's ``parent`` is not touched)."""
    self.children.append(child)

  def __repr__(self) -> str:
    # "<name> <n_classes> [<child reprs>]", nested recursively through the list repr.
    return " ".join((str(self.name), str(self.n_classes), str(self.children)))

  def get_tuple(self):
    """Return the ``(name, n_classes)`` key used by hierarchy dictionaries."""
    return self.name, self.n_classes

  def __iter__(self):
    """Pre-order traversal: this node first, then each child's subtree in order."""
    # Child iterators are built eagerly at call time (argument unpacking), then chained lazily.
    return chain([self], *(iter(child) for child in self.children))


class Tree(object):
  """Holds the root ``Node``; repr and iteration are delegated to the root."""

  def __init__(self, root : Node):
    self.root = root

  def __repr__(self):
    return repr(self.root)

  def __iter__(self):
    return self.root.__iter__()


In [None]:
# Parent (name, n_classes) -> list of child (name, n_classes) pairs.
# ("H", 2) never appears as a child, so it is the root of the resulting tree.
hierarchy = {
    ("A", 2): [("B", 5), ("C", 7)],
    ("H", 2): [("A", 2), ("K", 7), ("L", 10)],
}

def to_tree(hierarchy, root_node):
  """Recursively attach the children listed in ``hierarchy`` under ``root_node``.

  Args:
    hierarchy: dict mapping a parent ``(name, n_classes)`` tuple to a list of child
      ``(name, n_classes)`` tuples. It is consumed: each parent's entry is popped
      before its children are expanded.
    root_node: the ``Node`` whose ``get_tuple()`` key is expanded.

  Because an entry is removed before recursing, each key is expanded at most once.
  A node reached again (listed under two parents, or via a cycle) becomes a leaf
  instead of recursing forever.
  """
  # Direct dict lookup instead of scanning every item on each call.
  children = hierarchy.pop(root_node.get_tuple(), [])
  for child_key in children:
    child = Node(*child_key, root_node)
    root_node.add_child(child)
    to_tree(hierarchy, child)

def hierarchy_to_tree(hierarchy : Dict[Tuple, List[Tuple]]):
  """Build a ``Tree`` from a parent -> children mapping.

  Args:
    hierarchy: maps a parent ``(name, n_classes)`` tuple to a list of child
      ``(name, n_classes)`` tuples. It is not modified; a shallow copy is consumed.

  Returns:
    A ``Tree`` rooted at the single key that never appears as a child.

  Raises:
    ValueError: if there is not exactly one root. This covers an empty hierarchy,
      a hierarchy where every parent is also a child (a cycle), and several
      disconnected roots.
  """
  all_children = {child for children in hierarchy.values() for child in children}
  roots = [node for node in hierarchy if node not in all_children]
  # Zero roots previously fell through to root[0] on None and raised TypeError.
  if len(roots) != 1:
    raise ValueError("Invalid hierarchy tree.")
  root = roots[0]
  root_node = Node(root[0], root[1], None)
  to_tree(hierarchy.copy(), root_node)
  return Tree(root_node)

In [None]:
# Convert the `hierarchy` dict into a Tree; the bare last expression displays its repr.
tree = hierarchy_to_tree(hierarchy)
tree

H 2 [A 2 [B 5 [], C 7 []], K 7 [], L 10 []]

In [None]:
# Build a small tree by hand to show the Node/Tree API directly.
# Distinct names are used so the `tree` built from `hierarchy` is not overwritten;
# otherwise the iteration cell that follows would traverse this tree on Run All.
example_root = Node('A', 2, None)
child1 = Node('B', 3, example_root)
child2 = Node('C', 5, example_root)
child_of_child = Node('D', 3, child1)
example_root.add_child(child1)
example_root.add_child(child2)
child1.add_child(child_of_child)
example_tree = Tree(example_root)
print(example_tree)

A 2 [B 3 [D 3 []], C 5 []]


In [None]:
# Pre-order traversal; each printed line is a node together with its whole subtree.
for node in tree:
    print(node)

H 2 [A 2 [B 5 [], C 7 []], K 7 [], L 10 []]
A 2 [B 5 [], C 7 []]
B 5 []
C 7 []
K 7 []
L 10 []


In [None]:
class HierarchalModel(torch.nn.Module):
  """Multi-head classifier whose heads follow a class hierarchy.

  A shared ``base_model`` turns the input into features ``x``. Every node of the
  hierarchy tree gets its own head, an ``nn.Sequential`` stored in ``last_layers``
  under ``str((name, n_classes))``. The root head is fed ``x``; every other head is
  fed ``torch.cat((parent_output, x), dim_to_concat)``.

  Args:
    hierarchy: maps a parent ``(name, n_classes)`` tuple to a list of child
      ``(name, n_classes)`` tuples; converted with the module-level
      ``hierarchy_to_tree`` (its errors for invalid hierarchies propagate).
    size: three ints. ``size[0]`` is the feature width produced by the base model,
      ``size[1]`` the output width of each head's first Linear layer, and
      ``size[2]`` the input width of each head's final Linear layer.
    output_order: order of the tensors returned by ``forward``. When ``None`` or
      empty, the tree's pre-order traversal is used.
    base_model: shared feature extractor. If omitted, the first
      ``len(model) - k`` modules of ``model`` are used.
    model: layers split into a base model (all but the last ``k``) and a per-head
      block (the last ``k``, deep-copied into every head).
    k: number of trailing modules of ``model`` placed in each head.
    dim_to_concat: dimension along which a parent's output is concatenated to the
      features; defaults to 1 when ``None``.

  Raises:
    ValueError: if neither ``base_model`` nor ``model`` is given.
  """

  def __init__(self, hierarchy : Dict[Tuple, List[Tuple]], size : Tuple[int, int, int],
               output_order: Optional[List] = None, base_model: Optional[nn.Module] = None,
               model: Optional[nn.ModuleList] = None,
               k: Optional[int] = 0, dim_to_concat: Optional[int] = None):
    super().__init__()
    # `is not None` checks: an empty nn.Sequential / nn.ModuleList is falsy.
    if base_model is not None:
      self.base_model = base_model
    elif model is not None:
      self.base_model = nn.Sequential(*model[0:len(model) - k])
    else:
      raise ValueError("Either base_model or model must be provided.")
    self.tree = hierarchy_to_tree(hierarchy)
    # `is not None` so that dim_to_concat=0 is honoured instead of silently becoming 1.
    self.dim_to_concat = dim_to_concat if dim_to_concat is not None else 1

    # The last k modules of `model` serve as a template for every head.
    template = model[len(model) - k:len(model)] if model is not None else nn.ModuleList()
    heads = dict()
    for node in self.tree:
      n_parent_classes = node.parent.n_classes if node.parent is not None else 0
      layers = nn.ModuleList()
      # Input adapter: base features plus the parent's logits.
      layers.append(nn.Linear(size[0] + n_parent_classes, size[1]))
      # Slicing a ModuleList returns the *same* module objects. Without a copy, every
      # head would share (and jointly train) one set of weights instead of having its own.
      layers.extend(copy.deepcopy(template))
      layers.append(nn.Linear(size[2], node.n_classes))
      # ModuleDict keys must be strings, hence str() of the (name, n_classes) tuple.
      heads[str(node.get_tuple())] = nn.Sequential(*layers)
    self.last_layers = nn.ModuleDict(heads)
    # Resolve the default order up front so it is available before the first forward().
    self.output_order = output_order if output_order else [node.get_tuple() for node in self.tree]

  def forward(self, x):
    """Run the base model once, then every head in tree pre-order.

    Pre-order guarantees that a parent's output exists before its children need it.

    Returns:
      tuple of head outputs ordered by ``self.output_order``.
    """
    x = self.base_model(x)
    outputs_by_node = dict()
    for node in self.tree:
      key = node.get_tuple()
      head = self.last_layers[str(key)]
      if node.parent is not None:
        parent_out = outputs_by_node[node.parent.get_tuple()]
        outputs_by_node[key] = head(torch.cat((parent_out, x), self.dim_to_concat))
      else:
        outputs_by_node[key] = head(x)
    # Kept for callers that reset output_order to None/empty after construction.
    if not self.output_order:
      self.output_order = list(outputs_by_node.keys())
    return tuple(outputs_by_node[o] for o in self.output_order)

  def hierarchy_to_tree(self, hierarchy : Dict[Tuple, List[Tuple]]):
    """Backward-compatible alias for the module-level ``hierarchy_to_tree``.

    The method previously duplicated that function's body verbatim; the name below
    resolves to the module-level function, not to this method.
    """
    return hierarchy_to_tree(hierarchy)
    


## Testing

### Basic Testing of Singular Input

In [None]:
# Two Linear(10, 10) layers with k=1: the first is the shared base model and the
# last is placed in every per-class head.
model = HierarchalModel(
    model=nn.ModuleList([nn.Linear(10, 10) for _ in range(2)]),
    k=1,
    hierarchy=hierarchy,
    size=(10, 10, 10),
)
input = torch.rand((10, 10))
out = model(input)
for tensor, key in zip(out, model.output_order):
    print(key, " : ", tensor.shape)


('H', 2)  :  torch.Size([10, 2])
('A', 2)  :  torch.Size([10, 2])
('B', 5)  :  torch.Size([10, 5])
('C', 7)  :  torch.Size([10, 7])
('K', 7)  :  torch.Size([10, 7])
('L', 10)  :  torch.Size([10, 10])


In [None]:
# Display the hierarchy tree the model above was built from (last expression is rendered).
model.tree

H 2 [A 2 [B 5 [], C 7 []], K 7 [], L 10 []]

In [None]:
# NOTE(review): identical to the previous cell (displays model.tree again); candidate for removal.
model.tree

H 2 [A 2 [B 5 [], C 7 []], K 7 [], L 10 []]

## Example
Let us consider a dataset with data from two different cities.
Each city has 5 counties, and each county has 7 districts. We therefore define a simple hierarchy $A$ (city) $\rightarrow$ $B$ (county) $\rightarrow$ $C$ (district), where the output for $A$ has shape $n \times 2$, for $B$ shape $n \times 5$, and for $C$ shape $n \times 7$. The following example illustrates this in code.

In [None]:
from torch.utils.data import Dataset, DataLoader
class RegionDataset(Dataset):
  """Example hierarchical dataset filled with random data.

  Each sample is a 3 x 36 x 36 tensor (mimicking an image) with three independent
  random integer labels, one per hierarchy level (A, B, C).

  Args:
    length (int, optional): Size of dataset.
    transform (callable, optional): Optional transform to be applied on a sample.
    label_sizes (tuple of 3 ints, optional): number of classes for labels A, B and C.
      Defaults to (2, 5, 7).
    seed (int, optional): if given, data and labels are drawn from a torch.Generator
      seeded with it, so the dataset is reproducible. If None, the global torch RNG
      is used, as before.
  """

  def __init__(self, length=100, transform=None, label_sizes=(2, 5, 7), seed=None):
    self.length = length
    self.transform = transform
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    n_a, n_b, n_c = label_sizes
    self.data = torch.rand(length, 3, 36, 36, generator=generator)
    self.labelA = torch.randint(0, n_a, (length,), generator=generator)
    self.labelB = torch.randint(0, n_b, (length,), generator=generator)
    self.labelC = torch.randint(0, n_c, (length,), generator=generator)

  def __len__(self):
    return self.length

  def __getitem__(self, idx):
    """Return ``(sample, labelA, labelB, labelC)`` for ``idx``."""
    # Samplers may hand over tensor indices; convert them for plain indexing.
    if torch.is_tensor(idx):
      idx = idx.tolist()
    sample = self.data[idx]
    if self.transform:
      sample = self.transform(sample)
    return sample, self.labelA[idx], self.labelB[idx], self.labelC[idx]

In [None]:
# Mini-batch size for both loaders (previously referenced but never defined).
batch_size = 32
dataset = RegionDataset(length=1000)
percent_train = 0.8
train_size = int(percent_train * len(dataset))
val_size = len(dataset) - train_size  # remainder, so the split always covers every sample
trainset_t, valset_t = torch.utils.data.random_split(dataset, [train_size, val_size])
example_dataset = {'train': trainset_t,
                   'val': valset_t}
# Shuffle only the training split; validation order does not affect the metrics.
dataloaders = {phase: DataLoader(example_dataset[phase], batch_size=batch_size,
                                 shuffle=(phase == 'train'), num_workers=0)
               for phase in ['train', 'val']}
dataset_sizes = {phase: len(example_dataset[phase]) for phase in ['train', 'val']}

In [None]:
def train_model(model, criterion, optimizer, scheduler, nepochs, dataset_sizes, loaders=None):
    """Train a three-head (A, B, C) model and restore the best validation weights.

    Each epoch runs a 'train' phase and then a 'val' phase. The train phase does a
    forward pass, backward pass and optimizer step per batch, then one scheduler
    step. The val phase does forward passes only. Accuracy is printed per head and
    as the mean over the three heads. The state dict with the highest mean
    validation accuracy is loaded back into ``model`` before returning.

    Args:
        model: module whose forward returns exactly three logits tensors (A, B, C).
        criterion: loss applied to each head; the three losses are summed.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per training phase.
        nepochs: number of epochs.
        dataset_sizes: mapping 'train'/'val' -> number of samples, used to average
            loss and accuracy.
        loaders: optional mapping 'train'/'val' -> iterable of
            ``(inputs, labelA, labelB, labelC)`` batches. Defaults to the
            notebook-level ``dataloaders``.

    Returns:
        ``model``, with the best validation weights loaded.
    """
    if loaders is None:
        # Fall back to the notebook-level `dataloaders` built in the data cell.
        loaders = dataloaders
    start_time = time.time()
    best_model_val = copy.deepcopy(model.state_dict())
    best_val_acc = 0.0
    for epoch in range(nepochs):
        print('Epoch {}/{}'.format(epoch, nepochs - 1))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for ph in ['train', 'val']:
            running_loss = 0.0
            # Plain Python ints: well-defined even if a loader yields no batches.
            running_correctsA = 0
            running_correctsB = 0
            running_correctsC = 0
            if ph == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            for inputs, labelA, labelB, labelC in loaders[ph]:
                # reset gradients for every batch
                optimizer.zero_grad()

                # track autograd history only in the training phase
                with torch.set_grad_enabled(ph == 'train'):
                    outputA, outputB, outputC = model(inputs)

                    _, predsA = torch.max(outputA, 1)
                    _, predsB = torch.max(outputB, 1)
                    _, predsC = torch.max(outputC, 1)
                    # one loss per hierarchy level, summed with equal weight
                    loss = criterion(outputA, labelA) + criterion(outputB, labelB) + criterion(outputC, labelC)

                    # backward + optimize only in the training phase
                    if ph == 'train':
                        loss.backward()
                        optimizer.step()

                # assumes criterion averages over the batch (CrossEntropyLoss default),
                # hence the weighting by batch size
                running_loss += loss.item() * inputs.size(0)
                running_correctsA += int(torch.sum(predsA == labelA))
                running_correctsB += int(torch.sum(predsB == labelB))
                running_correctsC += int(torch.sum(predsC == labelC))

            if ph == 'train':
                scheduler.step()

            n_samples = dataset_sizes[ph]
            accA = running_correctsA / n_samples
            accB = running_correctsB / n_samples
            accC = running_correctsC / n_samples
            e_loss = running_loss / n_samples
            epoch_acc = (running_correctsA + running_correctsB + running_correctsC) / (3 * n_samples)
            print("Accuracy A {:.4f}, Accuracy B {:.4f}, Accuracy C {:.4f}".format(accA, accB, accC))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                ph, e_loss, epoch_acc))
            # keep the weights with the best validation accuracy
            if ph == 'val' and epoch_acc > best_val_acc:
                best_val_acc = epoch_acc
                best_model_val = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - start_time
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_val_acc))

    # load best model weights
    model.load_state_dict(best_model_val)
    return model

In [None]:
from torch.optim import lr_scheduler
from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout

# Hyperparameters for the example run.
# (The former `nclasses = 20` was unused and misleading; class counts come from `hierarchy`.)
nepochs = 20
lr = 0.001
# Seed the global torch RNG so weight init and DataLoader shuffling are repeatable from
# this cell on (the random dataset itself was generated earlier).
torch.manual_seed(0)

# LeNet-style feature extractor for 3 x 36 x 36 inputs: two conv(5) + pool(2) stages
# give a 16 x 6 x 6 = 576 map, which is flattened and reduced to 84 features.
model_base = nn.Sequential(
  nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5), 
  nn.ReLU(), 
  nn.MaxPool2d(kernel_size=2, stride=2), 
  nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), 
  nn.ReLU(), 
  nn.MaxPool2d(kernel_size=2, stride=2), 
  nn.Flatten(start_dim=1), 
  nn.Linear(in_features=576, out_features=120), 
  nn.ReLU(), 
  nn.Linear(in_features=120, out_features=84), 
  nn.ReLU()
)

# Three-level chain A (2 classes) -> B (5) -> C (7); with no explicit output_order,
# forward returns the heads in that (pre-order) sequence, matching train_model's unpacking.
hierarchy = {
    ('A', 2) : [('B', 5)],
    ('B', 5) : [('C', 7)]
}
# size=(84, 32, 32): each head maps (84 + parent classes) -> 32 -> its own classes.
# NOTE(review): with no `model`/k, each head is two Linear layers with no non-linearity between them.
model = HierarchalModel(hierarchy, (84, 32, 32), base_model=model_base, dim_to_concat=1)
criterion = nn.CrossEntropyLoss()
optimizer_ft = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
# Divide the learning rate by 10 every 10 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)

model_fully_trained = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
                                  nepochs, dataset_sizes)

Epoch 0/19
----------
Accuracy A 0.5262, Accuracy B 0.1875, Accuracy C 0.1275
train Loss: 4.2603 Acc: 0.2804
Accuracy A 0.5100, Accuracy B 0.1800, Accuracy C 0.1400
val Loss: 4.2598 Acc: 0.2767

Epoch 1/19
----------
Accuracy A 0.5238, Accuracy B 0.1875, Accuracy C 0.1275
train Loss: 4.2592 Acc: 0.2796
Accuracy A 0.5100, Accuracy B 0.1800, Accuracy C 0.1400
val Loss: 4.2585 Acc: 0.2767

Epoch 2/19
----------
Accuracy A 0.5238, Accuracy B 0.1875, Accuracy C 0.1275
train Loss: 4.2580 Acc: 0.2796
Accuracy A 0.5100, Accuracy B 0.1800, Accuracy C 0.1400
val Loss: 4.2573 Acc: 0.2767

Epoch 3/19
----------
Accuracy A 0.5238, Accuracy B 0.1875, Accuracy C 0.1275
train Loss: 4.2567 Acc: 0.2796
Accuracy A 0.5100, Accuracy B 0.1800, Accuracy C 0.1400
val Loss: 4.2561 Acc: 0.2767

Epoch 4/19
----------
Accuracy A 0.5238, Accuracy B 0.1875, Accuracy C 0.1275
train Loss: 4.2553 Acc: 0.2796
Accuracy A 0.5100, Accuracy B 0.1800, Accuracy C 0.1400
val Loss: 4.2552 Acc: 0.2767

Epoch 5/19
----------
Acc

Accuracy is not great, but the concept is there and works. The data we are using carries no signal (random tensors paired with random integer labels), so the network cannot find any connection between the labels and the "images". This is unsurprising and to be expected; the example merely illustrates how to use this model.

## Model Graphs

In [None]:
from torch.utils.tensorboard import SummaryWriter
model = HierarchalModel(model=nn.ModuleList([nn.Linear(10, 10) for _ in range(2)]), k=1, hierarchy=hierarchy, size=(10,10,10))
# Explicit example batch (10 samples x size[0] = 10 features) for graph tracing, instead of
# relying on an `input` variable left over from an earlier cell.
graph_input = torch.rand((10, 10))

# The context manager closes the writer even if tracing the graph raises.
with SummaryWriter('runs/model_graph') as writer:
    writer.add_graph(model, graph_input)