In [2]:
import time
from functools import wraps

def execution_timer(func):
    """Decorator: print the wall-clock duration of every call to ``func``.

    Timing uses ``time.perf_counter``. The wrapped function's return value is
    passed through unchanged, and ``functools.wraps`` keeps its name and docstring.
    """
    @wraps(func)
    def timed(*args, **kwargs):
        started = time.perf_counter()
        outcome = func(*args, **kwargs)
        duration = time.perf_counter() - started
        print(f"Function {func.__name__} executed in {duration:.4f} seconds")
        return outcome

    return timed

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import torch.optim as optim

class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and linearly embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` applies one
    linear projection per patch. Patch extraction and embedding therefore happen
    in a single op.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_chans: Number of input channels.
        embed_dim: Size of each patch embedding.
    """
    def __init__(self, img_size, patch_size, in_chans=1, embed_dim=64):
        super().__init__()

        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map images [B, C, H, W] to a token sequence [B, num_patches, embed_dim]."""
        patch_grid = self.proj(x)                                # [B, E, H/P, W/P]
        return patch_grid.flatten(start_dim=2).transpose(1, 2)   # [B, N_patches, E]

In [4]:
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Computes ``x + MHA(LN1(x))`` followed by ``x + MLP(LN2(x))``.

    Args:
        embed_dim: Token feature size (nn.MultiheadAttention requires it to be
            divisible by ``num_heads``).
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability, passed to the attention module and used
            twice inside the MLP.
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        # Multi-Head Self-Attention; batch_first=True -> inputs are [B, N, E]
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)

        self.norm2 = nn.LayerNorm(embed_dim)

        # Multi-Layer Perceptron (Feed Forward Network)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the block to a token sequence ``x`` of shape [B, N, embed_dim]; shape is preserved."""
        # MHA with residual connection.
        # need_weights=False: the averaged attention map was discarded anyway, and
        # skipping it lets PyTorch take the fused scaled_dot_product_attention path.
        x_norm = self.norm1(x)
        attn_output, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_output

        # MLP with residual connection
        x = x + self.mlp(self.norm2(x))
        return x

In [5]:
class ViTMini(nn.Module):
    """Minimal Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add learnable
    positional embeddings -> ``depth`` TransformerBlocks -> LayerNorm -> linear
    head applied to the [CLS] output.

    Args:
        img_size: Side length of the square input image.
        patch_size: Patch side length (28 / 7 = 4x4 = 16 patches with the defaults).
        in_chans: Number of input channels.
        num_classes: Number of output logits.
        embed_dim: Token embedding size.
        depth: Number of TransformerBlocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP width multiplier inside each block.
        dropout: Dropout probability for embeddings and blocks.
    """
    def __init__(self,
                 img_size=28,
                 patch_size=7, # 28 / 7 = 4x4 patches = 16 patches
                 in_chans=1,
                 num_classes=10,
                 embed_dim=64,
                 depth=2,
                 num_heads=4,
                 mlp_ratio=4.0,
                 dropout=0.1):
        super().__init__()

        self.num_features = self.embed_dim = embed_dim

        # 1. Patch Embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches

        # 2. [CLS] token and Positional Embeddings
        # CLS Token: its final representation is what the head classifies
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional Embeddings: one per patch plus one for the CLS token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        # 3. Transformer Encoder Blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # 4. Final Classification Head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Standard ViT initialisation for the learnable tokens. All-zero tokens make
        # every position and the CLS token start out identical. This runs last, so the
        # layers above consume the RNG in the same order as before.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward_features(self, x):
        """Run the transformer trunk.

        Args:
            x: Images [B, C, H, W]. H and W must yield ``num_patches`` patches
                (i.e. match ``img_size``); otherwise adding ``pos_embed`` fails.

        Returns:
            Layer-normalised token sequence [B, num_patches + 1, embed_dim].
            Index 0 along dim 1 is the [CLS] token.
        """
        B = x.shape[0]

        # Patch embeddings: [B, N_patches, E]
        x = self.patch_embed(x)

        # Prepend the CLS token (expanded to the batch): [B, N_patches + 1, E]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings, then dropout
        x = self.pos_drop(x + self.pos_embed)

        x = self.blocks(x)
        return self.norm(x)

    def forward(self, x):
        """Classify images [B, C, H, W]; returns logits [B, num_classes]."""
        # Only the output at the CLS position (index 0) is classified
        cls_output = self.forward_features(x)[:, 0]
        return self.head(cls_output)

In [6]:
# --- Setup ---
SEED = 42
# Seed torch's global RNG so weight init and DataLoader shuffling are reproducible
# under Restart & Run All (torch.manual_seed also seeds CUDA devices).
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# MNIST Images are 28x28. We choose a patch size that divides 28 evenly.
PATCH_SIZE = 7 # 28 / 7 = 4 -> 16 total patches

transform = transforms.Compose([
    transforms.ToTensor(),
    # Commonly used MNIST training-set mean / std
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load Data. Both splits pass download=True, so either line works on a fresh machine.
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

# Initialize Model
model = ViTMini(
    img_size=28,
    patch_size=PATCH_SIZE,
    in_chans=1,
    num_classes=10,
    embed_dim=64, # Small embedding size for fast training
    depth=2,      # Shallow transformer
    num_heads=4   # 4 attention heads
).to(device)

# --- Training Loop ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 5

@execution_timer
def train(model, loader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Classifier returning logits [B, num_classes].
        loader: Iterable of (data, target) batches; ``len(loader)`` is the batch count.
        criterion: Loss callable ``criterion(logits, target)`` (nn.CrossEntropyLoss here).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch loss values (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    for data, target in tqdm(loader, desc="Training"):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # Use the injected criterion instead of a hand-rolled log_softmax/gather NLL.
        # For CrossEntropyLoss the value is the same, and the `criterion` argument now matters.
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)

@execution_timer
def test(model, loader, device):
    """Evaluate top-1 accuracy (in percent) of ``model`` on ``loader`` and print a summary.

    Accuracy is computed against ``len(loader.dataset)`` samples. Gradients are disabled.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            logits = model(data.to(device))
            predicted = logits.argmax(dim=1)
            correct += (predicted == target.to(device)).sum().item()
    dataset_size = len(loader.dataset)
    accuracy = 100. * correct / dataset_size
    print(f'\nTest set: Accuracy: {correct}/{dataset_size} ({accuracy:.2f}%)')
    return accuracy

print("\n--- Starting ViT Training ---")
# Alternate one training epoch with one evaluation pass
for epoch_idx in range(1, epochs + 1):
    epoch_loss = train(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch_idx} Loss: {epoch_loss:.4f}")
    test(model, test_loader, device)

Using device: cpu


100%|██████████| 9.91M/9.91M [00:00<00:00, 37.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.06MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.66MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.07MB/s]



--- Starting ViT Training ---


Training: 100%|██████████| 469/469 [01:04<00:00,  7.31it/s]


Function train executed in 64.1436 seconds
Epoch 1 Loss: 0.7378

Test set: Accuracy: 9174/10000 (91.74%)
Function test executed in 3.7225 seconds


Training: 100%|██████████| 469/469 [01:10<00:00,  6.69it/s]


Function train executed in 70.0888 seconds
Epoch 2 Loss: 0.2788

Test set: Accuracy: 9512/10000 (95.12%)
Function test executed in 3.8052 seconds


Training: 100%|██████████| 469/469 [01:00<00:00,  7.71it/s]


Function train executed in 60.8065 seconds
Epoch 3 Loss: 0.2015

Test set: Accuracy: 9595/10000 (95.95%)
Function test executed in 3.6441 seconds


Training: 100%|██████████| 469/469 [01:00<00:00,  7.76it/s]


Function train executed in 60.4224 seconds
Epoch 4 Loss: 0.1658

Test set: Accuracy: 9651/10000 (96.51%)
Function test executed in 3.5391 seconds


Training: 100%|██████████| 469/469 [01:00<00:00,  7.76it/s]


Function train executed in 60.4480 seconds
Epoch 5 Loss: 0.1436

Test set: Accuracy: 9705/10000 (97.05%)
Function test executed in 3.6426 seconds


In [7]:
import heapq

class Node:
    """Huffman tree node.

    Leaves carry a vocabulary symbol. Internal nodes get a label from the tree builder.
    """

    def __init__(self, symbol=None, frequency=None):
        self.symbol, self.frequency = symbol, frequency
        self.left = self.right = None

    def __lt__(self, other):
        # heapq orders nodes by frequency alone
        return self.frequency < other.frequency

    def is_leaf(self):
        """True when the node has no children."""
        return all(child is None for child in (self.left, self.right))

def build_tree(leaf_nodes, frequencies):
  """Build a Huffman tree over ``leaf_nodes`` and return its root ``Node``.

  The two lowest-frequency nodes are merged repeatedly. Merged (internal) nodes
  get unique string labels 'Internal Node 0', 'Internal Node 1', ... in
  creation order, and their frequency is the sum of their children's.

  Args:
    leaf_nodes: Symbols to place at the leaves.
    frequencies: One frequency per symbol, in the same order.

  Returns:
    The root Node. It is the single leaf itself when only one symbol is given.

  Raises:
    ValueError: if ``leaf_nodes`` is empty or the two sequences differ in length.
  """
  leaf_nodes = list(leaf_nodes)
  frequencies = list(frequencies)
  if len(leaf_nodes) != len(frequencies):
    raise ValueError(
      f"got {len(leaf_nodes)} symbols but {len(frequencies)} frequencies"
    )
  if not leaf_nodes:
    raise ValueError("cannot build a Huffman tree with no symbols")

  # Create a priority queue of leaf nodes, keyed on frequency via Node.__lt__
  priority_queue = [Node(val, freq) for val, freq in zip(leaf_nodes, frequencies)]
  heapq.heapify(priority_queue)

  internal_node_counter = 0
  # Build the Huffman tree
  while len(priority_queue) > 1:
    left_child = heapq.heappop(priority_queue)
    right_child = heapq.heappop(priority_queue)
    merged_node = Node(
      symbol=f'Internal Node {internal_node_counter}', frequency=left_child.frequency + right_child.frequency
    )
    internal_node_counter += 1  # keep internal labels unique
    merged_node.left = left_child
    merged_node.right = right_child
    heapq.heappush(priority_queue, merged_node)
  return priority_queue[0]

def generate_paths(node, code, path_dict):
  """Record the root-to-leaf branch code of every leaf below ``node``.

  Going left appends 0 and going right appends 1. Leaves are identified
  structurally with ``is_leaf()``, so string vocabulary symbols are kept too.
  Only internal nodes are skipped.

  Args:
    node: Subtree root, or None.
    code: Branch choices (list of 0/1) leading to ``node``.
    path_dict: Dict filled in place, mapping leaf symbol -> code.

  Returns:
    ``path_dict`` (the same object), filled in pre-order, left before right.
  """
  if node is None:
    return path_dict
  if node.is_leaf():
    path_dict[node.symbol] = code
    return path_dict
  generate_paths(node.left, code + [0], path_dict)
  generate_paths(node.right, code + [1], path_dict)
  return path_dict

def max_depth(node):
  """Height of the tree rooted at ``node``: 0 for None, 1 for a lone leaf."""
  depth = 0
  level = [] if node is None else [node]
  # Walk the tree level by level; each non-empty level adds one to the height
  while level:
    depth += 1
    level = [child for n in level for child in (n.left, n.right) if child is not None]
  return depth

In [8]:
# Build a tree over the 10 MNIST classes with uniform frequencies and check coverage
classes = list(range(10))
root_node = build_tree(classes, [1] * len(classes))
paths = generate_paths(root_node, [], {})
missing_classes = [item for item in classes if item not in paths]
print(f"Paths for classes: {paths}")
print(f"All items in vocabulary present in tree: {not missing_classes}")
print(f"Number of items in vocabulary not present in tree: {len(missing_classes)}")

Paths for classes: {4: [0, 0, 0], 8: [0, 0, 1], 6: [0, 1, 0], 2: [0, 1, 1], 3: [1, 0, 0], 5: [1, 0, 1], 0: [1, 1, 0, 0], 1: [1, 1, 0, 1], 7: [1, 1, 1, 0], 9: [1, 1, 1, 1]}
All items in vocabulary present in tree: True
Number of items in vocabulary not present in tree: 0


In [9]:
from collections import deque

class HierarchicalSoftmaxNodeViT(nn.Module):
    """Hierarchical-softmax classification head over a binary (Huffman) tree.

    Each internal tree node owns a bias-free ``nn.Linear(hidden_size, 1)`` that
    scores the branch decision at that node (label 0 = left child, 1 = right
    child). A sample's loss is the sum of binary cross-entropies along the
    root-to-leaf path of its target class. ``forward`` returns the mean of that
    sum over samples whose target is a leaf of the tree.

    Args:
        root: Root ``Node`` of the tree.
        hidden_size: Feature size of the incoming hidden states.
    """

    def __init__(self, root, hidden_size):
        super().__init__()
        self.root = root
        self.hidden_size = hidden_size
        print(f"Using hidden size: {self.hidden_size}")
        self.paths = generate_paths(root, [], {})
        self.node_name_map = {}
        self.node_weights = nn.ModuleDict()
        self.param_counter = 0

        def initialize_node_parameters(node):
            # Pre-order walk. Internal nodes get consecutive string keys "0", "1", ...
            # in node_weights; node_name_map maps each Node object to its key.
            if node is None or node.is_leaf():
                return None
            node_str = str(self.param_counter)
            self.node_name_map[node] = node_str
            self.node_weights[node_str] = nn.Linear(self.hidden_size, 1, bias=False)
            self.param_counter += 1
            initialize_node_parameters(node.left)
            initialize_node_parameters(node.right)

        initialize_node_parameters(self.root)
        print(f"HSM initialized with {len(self.node_weights)} internal nodes")

        # Per class: the (internal-node index, branch label) decisions on its path,
        # precomputed once so forward needs no tree walk
        self.path_steps = {
            symbol: self._trace_path(choices) for symbol, choices in self.paths.items()
        }

    def _trace_path(self, choices):
        """Turn a 0/1 branch code into [(node_index, float(label)), ...] from the root."""
        steps = []
        curr = self.root
        for choice in choices:
            if curr.is_leaf():
                break
            steps.append((int(self.node_name_map[curr]), float(choice)))
            curr = curr.left if not choice else curr.right
        return steps

    # NOTE: The forward pass for the loss calculation is slightly simplified
    # for the ViT classification head where the input is [B, E] and the target is [B]
    def forward(self, hidden_state, target_ids):
        """Mean hierarchical-softmax loss.

        Args:
            hidden_state: Tensor [B, hidden_size].
            target_ids: Tensor [B] of class ids. Ids that are not leaves of the
                tree are skipped and excluded from the mean.

        Returns:
            Scalar loss tensor on ``hidden_state``'s device.
        """
        sample_rows, node_rows, branch_labels = [], [], []
        total_valid_tokens = 0
        # A single host transfer for the targets instead of one .item() per sample
        for row, target in enumerate(target_ids.tolist()):
            steps = self.path_steps.get(target)
            if steps is None:
                continue
            total_valid_tokens += 1
            for node_index, label in steps:
                sample_rows.append(row)
                node_rows.append(node_index)
                branch_labels.append(label)

        if not sample_rows:
            # No branch decisions to score: zero loss that backward() still accepts
            return hidden_state.new_zeros((), requires_grad=True)

        device = hidden_state.device
        # Stack node weights into [num_internal_nodes, hidden_size]; row i is key str(i)
        node_matrix = torch.cat(
            [self.node_weights[str(i)].weight for i in range(len(self.node_weights))], dim=0
        )
        sample_idx = torch.tensor(sample_rows, device=device)
        node_idx = torch.tensor(node_rows, device=device)
        labels = torch.tensor(branch_labels, device=device, dtype=hidden_state.dtype)

        # One dot product per (sample, path node) pair, i.e. W_node(h_sample), batched
        logits = (hidden_state[sample_idx] * node_matrix[node_idx]).sum(dim=-1)
        loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='sum')
        return loss / max(1, total_valid_tokens)


In [13]:
class ViTMiniHSM(ViTMini):
    """ViTMini whose linear classifier is replaced by a hierarchical-softmax head.

    ``forward(x, targets)`` returns ``{"loss": scalar}`` when ``targets`` is given.
    Otherwise it returns a LongTensor [B] of greedily decoded class ids.

    Accepts the same keyword arguments as ViTMini. Sizes are read from the built
    base model, so ViTMini's defaults apply to any omitted kwarg.
    """
    def __init__(self, **kwargs):
        # Call base constructor to set up patch embedding and transformer blocks
        super().__init__(**kwargs)
        num_classes = self.head.out_features
        class_ids = list(range(num_classes))

        # 1. Build the HSM tree over all classes (uniform frequencies)
        self.root = build_tree(class_ids, [1] * num_classes)
        tree_paths = generate_paths(self.root, [], {})
        print(f"Paths for classes: {tree_paths}")
        print(f"All items in vocabulary present in tree: {all(item in tree_paths for item in class_ids)}")
        print(f"Number of items in vocabulary not present in tree: {len([item for item in class_ids if item not in tree_paths])}")

        # 2. Replace the standard head with the HSM head, sized to the CLS embedding
        self.hsm_head = HierarchicalSoftmaxNodeViT(
            root=self.root,
            hidden_size=self.embed_dim
        )

        # Remove the standard linear head defined in the base class
        del self.head

    def _greedy_predict(self, hidden_state):
        """Greedy root-to-leaf descent for one hidden state [hidden]; returns the leaf symbol."""
        curr = self.hsm_head.root
        while not curr.is_leaf():
            W = self.hsm_head.node_weights[self.hsm_head.node_name_map[curr]]
            # sigmoid(z) > 0.5 <=> z > 0; branch label 1 means "go right"
            curr = curr.right if W(hidden_state).item() > 0 else curr.left
        return curr.symbol

    def _greedy_predict_batch(self, cls_output):
        """Greedy decoding for a batch [B, hidden]. Returns a list of B leaf symbols."""
        heads = self.hsm_head.node_weights
        if len(heads) == 0:
            # Single-leaf tree: every sample maps to that leaf
            return [self.hsm_head.root.symbol] * cls_output.shape[0]
        node_matrix = torch.cat([heads[str(i)].weight for i in range(len(heads))], dim=0)
        # Score every internal node for every sample in one matmul and one host transfer
        node_logits = (cls_output @ node_matrix.t()).tolist()
        predictions = []
        for row in node_logits:
            curr = self.hsm_head.root
            while not curr.is_leaf():
                go_right = row[int(self.hsm_head.node_name_map[curr])] > 0
                curr = curr.right if go_right else curr.left
            predictions.append(curr.symbol)
        return predictions

    def forward(self, x, targets=None):
        """Images [B, C, H, W] -> {"loss": ...} if ``targets`` [B] is given, else predicted ids [B]."""
        # 1. Base Transformer Forward Pass (same as original ViT)
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        x = self.blocks(x)

        cls_output = self.norm(x)[:, 0]

        # 2. Head Calculation
        if targets is not None:
            # Targets are the true class labels [B]
            return {"loss": self.hsm_head(cls_output, targets)}
        predictions = self._greedy_predict_batch(cls_output)
        return torch.tensor(predictions, dtype=torch.long, device=cls_output.device)

In [11]:
# --- Updated Training and Testing Logic ---

@execution_timer
def train_hsm_vit(model, loader, optimizer, device):
    """Run one training epoch of a ViTMiniHSM-style model; return the mean per-batch loss.

    The model is called as ``model(images, targets=labels)`` and must return a
    dict holding a scalar tensor under ``'loss'``.
    """
    model.train()
    running_loss = 0
    for images, labels in tqdm(loader, desc="Training (HSM)"):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        batch_loss = model(images, targets=labels)['loss']

        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    return running_loss / len(loader)

@execution_timer
def test_hsm_vit(model, loader, device):
    """
    Evaluates the ViT-HSM model.

    Calls ``model(data)`` without targets, so the model returns a LongTensor [B]
    of predicted class ids. These are compared with the labels.

    Returns:
        Accuracy in percent over all samples seen (0.0 if the loader is empty).
    """
    model.eval()
    correct = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in tqdm(loader, desc="Testing"):
            data, target = data.to(device), target.to(device)

            # No targets -> prediction mode; predicted_classes shape: [B]
            predicted_classes = model(data)

            correct += predicted_classes.eq(target).sum().item()
            total_samples += target.size(0)

    # Guard against an empty loader instead of raising ZeroDivisionError
    accuracy = 100. * correct / total_samples if total_samples else 0.0
    # The emoji was stored mis-encoded (printed as "ðŸ“ˆ"); restored to U+1F4C8
    print(f'\n📈 HSM Test set: Accuracy: {correct}/{total_samples} ({accuracy:.2f}%)')
    return accuracy

In [14]:
# Initialize the HSM variant with the same hyper-parameters as the baseline ViT
model_hsm = ViTMiniHSM(
    img_size=28, patch_size=7, in_chans=1, num_classes=10,
    embed_dim=64, depth=2, num_heads=4,
).to(device)

# --- Execute Training ---
epochs = 5
optimizer_hsm = optim.Adam(model_hsm.parameters(), lr=1e-3)

# Alternate one HSM training epoch with one evaluation pass
for epoch_idx in range(1, epochs + 1):
    epoch_loss = train_hsm_vit(model_hsm, train_loader, optimizer_hsm, device)
    print(f"Epoch {epoch_idx} HSM Loss: {epoch_loss:.4f}")
    test_hsm_vit(model_hsm, test_loader, device)

Paths for classes: {4: [0, 0, 0], 8: [0, 0, 1], 6: [0, 1, 0], 2: [0, 1, 1], 3: [1, 0, 0], 5: [1, 0, 1], 0: [1, 1, 0, 0], 1: [1, 1, 0, 1], 7: [1, 1, 1, 0], 9: [1, 1, 1, 1]}
All items in vocabulary present in tree: True
Number of items in vocabulary not present in tree: 0
Using hidden size: 64
HSM initialized with 9 internal nodes


Training (HSM): 100%|██████████| 469/469 [01:37<00:00,  4.83it/s]


Function train_hsm_vit executed in 97.0251 seconds
Epoch 1 HSM Loss: 0.8429


Testing: 100%|██████████| 79/79 [00:04<00:00, 16.06it/s]



📈 HSM Test set: Accuracy: 9112/10000 (91.12%)
Function test_hsm_vit executed in 4.9252 seconds


Training (HSM): 100%|██████████| 469/469 [01:37<00:00,  4.79it/s]


Function train_hsm_vit executed in 97.9310 seconds
Epoch 2 HSM Loss: 0.3382


Testing: 100%|██████████| 79/79 [00:04<00:00, 15.98it/s]



📈 HSM Test set: Accuracy: 9413/10000 (94.13%)
Function test_hsm_vit executed in 4.9489 seconds


Training (HSM): 100%|██████████| 469/469 [01:36<00:00,  4.86it/s]


Function train_hsm_vit executed in 96.4398 seconds
Epoch 3 HSM Loss: 0.2407


Testing: 100%|██████████| 79/79 [00:04<00:00, 16.33it/s]



📈 HSM Test set: Accuracy: 9515/10000 (95.15%)
Function test_hsm_vit executed in 4.8450 seconds


Training (HSM): 100%|██████████| 469/469 [01:35<00:00,  4.93it/s]


Function train_hsm_vit executed in 95.1698 seconds
Epoch 4 HSM Loss: 0.1943


Testing: 100%|██████████| 79/79 [00:04<00:00, 16.55it/s]



📈 HSM Test set: Accuracy: 9619/10000 (96.19%)
Function test_hsm_vit executed in 4.7781 seconds


Training (HSM): 100%|██████████| 469/469 [01:35<00:00,  4.93it/s]


Function train_hsm_vit executed in 95.1129 seconds
Epoch 5 HSM Loss: 0.1631


Testing: 100%|██████████| 79/79 [00:04<00:00, 16.41it/s]


📈 HSM Test set: Accuracy: 9654/10000 (96.54%)
Function test_hsm_vit executed in 4.8181 seconds



