<a href="https://colab.research.google.com/github/tally0818/OFA/blob/main/OFA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [None]:
!pip install calflops

In [None]:
from calflops import calculate_flops

In [None]:
def get_top_k_matrices(matrices, k):
  """Return the ``k`` slices of ``matrices`` (along dim 0) with the largest L1 norm.

  Args:
    matrices: tensor whose first dimension indexes the candidate slices.
    k: number of slices to keep; if it exceeds ``matrices.size(0)`` every
      slice is returned.

  Returns:
    A list of ``matrices[i]`` slices ordered by descending L1 norm.
  """
  # The ranking itself needs no gradient; only the returned slices do.
  with torch.no_grad():
    magnitudes = matrices.abs()
    if magnitudes.dim() > 1:
      # One reduction over all trailing dims instead of a Python loop that
      # rebuilds a tensor from per-slice scalars.
      magnitudes = magnitudes.flatten(1).sum(dim=1)
    ranking = magnitudes.argsort(descending=True)[:k].tolist()
  return [matrices[i] for i in ranking]

In [None]:
class ElasticConv(nn.Module):
  """Pointwise -> depthwise -> pointwise convolution with swappable active kernels.

  The full-size kernels (``p1_kernels``, ``dep_kernels``, ``p2_kernels``) are
  registered parameters. ``forward`` uses the ``active_*`` tensors, which start
  out as the full kernels and may be replaced from outside with plain tensors.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels produced by the first pointwise conv and kept by
      the depthwise and second pointwise convs.
    max_kernel_size: spatial size of the stored depthwise kernels.
    device: device the kernels are created on.
  """
  def __init__(self, in_channels : int, out_channels : int, max_kernel_size : int = 7, device = 'cuda'):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.max_kernel_size = max_kernel_size
    self.active_kernel_size = self.max_kernel_size
    self.device = device
    self.stride = 1
    # Move the data to the device *before* wrapping it: nn.Parameter(...).to(dev)
    # returns a plain non-leaf tensor whenever a copy happens, which is then
    # neither registered on the module nor seen by optimizers / state_dict.
    self.p1_kernels = nn.Parameter(torch.randn(self.out_channels,
                                               self.in_channels,
                                               1,
                                               1).to(self.device))

    self.dep_kernels = nn.Parameter(torch.randn(self.out_channels,
                                                1,
                                                self.max_kernel_size,
                                                self.max_kernel_size).to(self.device))

    self.p2_kernels = nn.Parameter(torch.randn(self.out_channels,
                                               self.out_channels,
                                               1,
                                               1).to(self.device))

    # Bypass nn.Module.__setattr__ so the aliases are not registered as a second
    # set of parameters; a registered name could no longer be re-assigned to
    # the plain tensors produced when the kernels are shrunk.
    object.__setattr__(self, 'active_dep_kernels', self.dep_kernels)
    object.__setattr__(self, 'active_p1_kernels', self.p1_kernels)
    object.__setattr__(self, 'active_p2_kernels', self.p2_kernels)
    self.ReLU6 = nn.ReLU6()

  def forward(self, x):
    """Apply 1x1 conv, ReLU6, depthwise conv, 1x1 conv, ReLU6 with the active kernels."""
    x = F.conv2d(x, self.active_p1_kernels, stride=self.stride)
    x = self.ReLU6(x)
    x = F.conv2d(x, self.active_dep_kernels,
                 stride=1,
                 padding = self.active_kernel_size // 2,  # keeps spatial size for odd kernels
                 groups = self.out_channels)
    x = F.conv2d(x, self.active_p2_kernels, stride=1)
    x = self.ReLU6(x)
    return x

In [None]:
class Conv(nn.Module):
  """Fixed pointwise -> depthwise -> pointwise block built from given kernels.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels produced by every conv in the block.
    p1_kernels: weight for the first 1x1 conv.
    dep_kernels: weight for the depthwise conv.
    p2_kernels: weight for the second 1x1 conv.
    kernel_size: depthwise kernel size; the depthwise padding is
      ``kernel_size // 2``.
    stride: stride of the first 1x1 conv.
    padding: stored on the instance only; the convolutions do not read it.
  """
  def __init__(self, in_channels : int, out_channels : int, p1_kernels, dep_kernels, p2_kernels, kernel_size : int = 3, stride : int = 1, padding : int = 1):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.p1_kernels = p1_kernels
    self.dep_kernels = dep_kernels
    self.p2_kernels = p2_kernels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.ReLU6 = nn.ReLU6()
    # bias=False throughout: only weights are supplied, so a default
    # (randomly initialised) bias would make the output differ from the plain
    # F.conv2d pipeline these kernels were taken from.
    self.point_conv1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=self.stride, bias=False)
    self.point_conv1.weight.data = self.p1_kernels
    self.dep_conv = nn.Conv2d(self.out_channels, self.out_channels, kernel_size = self.kernel_size,
                              padding = self.kernel_size // 2, groups = self.out_channels, bias=False)
    self.dep_conv.weight.data = self.dep_kernels
    self.point_conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1, stride=1, bias=False)
    self.point_conv2.weight.data = self.p2_kernels

  def forward(self, x):
    """Apply 1x1 conv, ReLU6, depthwise conv, 1x1 conv, ReLU6."""
    x = self.point_conv1(x)
    x = self.ReLU6(x)
    x = self.dep_conv(x)
    x = self.point_conv2(x)
    x = self.ReLU6(x)
    return x

In [None]:
def get_fixed_Conv(ElasticConv : ElasticConv)->Conv:
  """Build a fixed ``Conv`` from the currently active kernels of an ``ElasticConv``.

  The active kernels are cloned, so the returned block owns its weights. The
  elastic block's stride is forwarded as well; otherwise a strided elastic
  layer would be exported as a stride-1 layer with a different output size.
  """
  p1_kernels = ElasticConv.active_p1_kernels.clone()
  dep_kernels = ElasticConv.active_dep_kernels.clone()
  p2_kernels = ElasticConv.active_p2_kernels.clone()
  return Conv(ElasticConv.in_channels,
              ElasticConv.out_channels,
              p1_kernels,
              dep_kernels,
              p2_kernels,
              ElasticConv.active_kernel_size,
              stride=ElasticConv.stride)

In [None]:
class ElasticSqueezeAndExcite(nn.Module):
  '''
  Squeeze-and-excite gate whose 1x1 weights can be shrunk to fewer channels.

  fc1_weights / fc2_weights hold the full-size weights and are registered
  parameters. forward uses active_fc1_weights / active_fc2_weights, which start
  as the full weights and are replaced by shrink().

  channels  : number of channels of the gated feature map
  reduction : ratio between channels and the bottleneck width
  device    : device the weights are created on
  '''
  def __init__(self, channels : int, reduction : int = 4, device = 'cuda'):
    super().__init__()
    self.channels = channels
    self.reduction = reduction
    self.device = device
    # Never let the bottleneck collapse to zero channels (same rule as SqueezeAndExcite).
    self.reduced_channels = max(1, self.channels // self.reduction)
    # Move to the device before wrapping: nn.Parameter(...).to(dev) returns a
    # plain non-leaf tensor whenever a copy happens, so the weights would not
    # be registered, trained or saved.
    self.fc1_weights = nn.Parameter(torch.randn(self.reduced_channels, self.channels, 1, 1).to(device))
    self.fc2_weights = nn.Parameter(torch.randn(self.channels, self.reduced_channels, 1, 1).to(device))
    # Plain aliases (not registered), so shrink() can later assign ordinary tensors.
    object.__setattr__(self, 'active_fc1_weights', self.fc1_weights)
    object.__setattr__(self, 'active_fc2_weights', self.fc2_weights)
    self.active_channels = self.channels
    self.active_reduced_channels = self.reduced_channels

  def _process_kernels(self, weights, out_ch_count):
    '''
    Keep the out_ch_count slices of weights (along dim 0) with the largest L1 norm.

    weights : (channels, in_channels, s, s) -> (out_ch_count, in_channels, s, s) or
    weights : (channels, s, s) -> (out_ch_count, s, s)
    '''
    top_weights = get_top_k_matrices(weights, out_ch_count)
    return torch.stack(top_weights).to(self.device)

  def shrink(self, elastic_channels):
    '''
    Select active weights for elastic_channels channels; no-op if already active.

    active_fc1_weights : (reduced_channels, channels, 1, 1) -> (elastic_channels//4, elastic_channels, 1, 1)
    active_fc2_weights : (channels, reduced_channels, 1, 1) -> (elastic_channels, elastic_channels//4, 1, 1)
    (the bottleneck width is elastic_channels // reduction, at least 1)
    '''
    if elastic_channels == self.active_channels:
      return

    self.active_channels = elastic_channels
    self.active_reduced_channels = max(1, self.active_channels // self.reduction)

    # First keep the strongest output rows, then the strongest inputs within each row.
    fc1_in_weights = self._process_kernels(self.fc1_weights, self.active_reduced_channels)
    processed_fc1_weights = []
    for ch_weight in fc1_in_weights:
      reduced_top = get_top_k_matrices(ch_weight, self.active_channels)
      processed_fc1_weights.append(torch.stack(reduced_top))

    self.active_fc1_weights = torch.stack(processed_fc1_weights).to(self.device)

    fc2_in_weights = self._process_kernels(self.fc2_weights, self.active_channels)
    processed_fc2_weights = []
    for ch_weight in fc2_in_weights:
      out_top = get_top_k_matrices(ch_weight, self.active_reduced_channels)
      processed_fc2_weights.append(torch.stack(out_top))

    self.active_fc2_weights = torch.stack(processed_fc2_weights).to(self.device)

  def forward(self, x):
    '''Scale x channel-wise by a gate computed from its global average.'''
    scale = F.adaptive_avg_pool2d(x, 1)
    scale = F.conv2d(scale, self.active_fc1_weights)
    scale = F.relu(scale)
    scale = F.conv2d(scale, self.active_fc2_weights)
    scale = F.hardsigmoid(scale)
    return x * scale

In [None]:
class SqueezeAndExcite(nn.Module):
  """Fixed squeeze-and-excite gate built from two bias-free 1x1 convolutions.

  Args:
    channels: number of channels of the gated feature map.
    reduction: ratio between ``channels`` and the bottleneck width (the
      bottleneck is at least one channel wide).
    fc1_weights: optional weight for the squeezing conv; when omitted the
      default ``nn.Conv2d`` initialisation is kept.
    fc2_weights: optional weight for the exciting conv; when omitted the
      default ``nn.Conv2d`` initialisation is kept.
  """
  def __init__(self, channels : int, reduction : int = 4, fc1_weights = None, fc2_weights = None):
    super().__init__()
    self.channels = channels
    self.reduction = reduction
    self.reduced_channels = max(1, self.channels // self.reduction)
    self.fc1 = nn.Conv2d(self.channels, self.reduced_channels, 1, bias=False)
    self.fc2 = nn.Conv2d(self.reduced_channels, self.channels, 1, bias=False)
    # The weights default to None; assigning None to ``.data`` raises, so only
    # override the initial weights when something was actually supplied.
    if fc1_weights is not None:
      self.fc1.weight.data = fc1_weights
    if fc2_weights is not None:
      self.fc2.weight.data = fc2_weights

  def forward(self, x):
    """Scale ``x`` channel-wise by a gate computed from its global average."""
    scale = F.adaptive_avg_pool2d(x, 1)
    scale = self.fc1(scale)
    scale = F.relu(scale)
    scale = self.fc2(scale)
    scale = F.hardsigmoid(scale)
    return x * scale

In [None]:
def get_fixed_SqueezeAndExcite(ElasticSqueezeAndExcite : ElasticSqueezeAndExcite) -> SqueezeAndExcite:
  """Build a fixed ``SqueezeAndExcite`` from the active weights of an elastic one.

  The active weights are cloned, as ``get_fixed_Conv`` does for its kernels, so
  the fixed block does not share storage with the elastic block it came from.
  """
  return SqueezeAndExcite(
      ElasticSqueezeAndExcite.active_channels,
      ElasticSqueezeAndExcite.reduction,
      ElasticSqueezeAndExcite.active_fc1_weights.clone(),
      ElasticSqueezeAndExcite.active_fc2_weights.clone()
  )

In [None]:
class ElasticMBblock(nn.Module): #layer
  """One elastic layer: an ElasticConv followed by an ElasticSqueezeAndExcite.

  The block widens its input by ``max_width`` (``out_channels = in_channels *
  max_width``) and owns two learnable matrices that map flattened depthwise
  kernels to smaller kernel sizes. The attribute names ``transfrom_matrix_*``
  are kept as spelled because they are parameter names.
  """
  def __init__(self, in_channels : int, max_width : int = 6, max_kernel_size : int = 7, device = 'cuda'):
    super().__init__()
    self.in_channels = in_channels
    self.max_width = max_width
    self.max_kernel_size = max_kernel_size
    self.device = device
    self.out_channels = in_channels * self.max_width
    # (max_kernel_size^2, 25): flattened full kernel -> flattened 5x5 kernel
    self.transfrom_matrix_725 = nn.Parameter(torch.randn(self.max_kernel_size**2, 25))
    # (25, 9): flattened 5x5 kernel -> flattened 3x3 kernel
    self.transfrom_matrix_523 = nn.Parameter(torch.randn(25, 9))
    self.dep_conv = ElasticConv(self.in_channels, self.out_channels, self.max_kernel_size, self.device)
    self.se = ElasticSqueezeAndExcite(self.out_channels, 4, self.device)
    self.active_kernel_size = self.max_kernel_size
    self.active_width = self.max_width

  def get_transform_matrix(self, elastic_kernel_size : int):
    """Return the matrix mapping a flattened full kernel to ``elastic_kernel_size``.

    Raises:
      ValueError: if the size is neither the maximum size, 5, nor 3.
    """
    if elastic_kernel_size == self.max_kernel_size:
      # Full size: identity, the kernel is kept as is.
      return torch.eye(self.max_kernel_size**2).to(self.device)
    if elastic_kernel_size == 5:
      return self.transfrom_matrix_725
    if elastic_kernel_size == 3:
      # Chain full -> 5x5 -> 3x3.
      chained = torch.matmul(self.transfrom_matrix_725, self.transfrom_matrix_523)
      return chained.to(self.device)
    raise ValueError("Unsupported kernel size")

  def _process_kernels(self, kernels, in_ch_count, transform_matrix=None):
    """Keep the ``out_channels`` strongest kernels, then resize or prune each one.

    With a ``transform_matrix`` (depthwise kernels) every kernel is flattened,
    multiplied by the matrix and reshaped to the active kernel size. Without
    one (pointwise kernels) the ``in_ch_count`` strongest input slices of each
    kernel are kept.
    """
    selected = get_top_k_matrices(kernels, self.out_channels)
    processed = []
    if transform_matrix is not None:
      side = self.active_kernel_size
      for group in selected:
        resized = [torch.matmul(kern.view(1, -1), transform_matrix).view(side, side)
                   for kern in group]
        processed.append(torch.stack(resized))
    else:
      for group in selected:
        processed.append(torch.stack(get_top_k_matrices(group, in_ch_count)))

    return torch.stack(processed).to(self.device)

  def shrink(self, elastic_kernel_size: int, elastic_width: int):
    """Activate the sub-kernels for the given kernel size and width multiplier."""
    self.active_kernel_size = elastic_kernel_size
    self.active_width = elastic_width
    self.dep_conv.active_kernel_size = elastic_kernel_size

    width_channels = self.in_channels * elastic_width
    self.out_channels = width_channels
    self.dep_conv.out_channels = width_channels
    self.se.shrink(width_channels)

    resize = self.get_transform_matrix(elastic_kernel_size)
    conv = self.dep_conv
    conv.active_p1_kernels = self._process_kernels(conv.p1_kernels, self.in_channels)
    conv.active_dep_kernels = self._process_kernels(conv.dep_kernels, None, resize)
    conv.active_p2_kernels = self._process_kernels(conv.p2_kernels, self.out_channels)


  def forward(self, x):
    """Run conv and squeeze-excite; add the input back when channel counts match."""
    shortcut = x.clone()
    out = self.se(self.dep_conv(x))
    if self.in_channels == self.out_channels:
      out += shortcut # optional residual connection
    return out


In [None]:
class MBblock(nn.Module): #layer
  """Fixed layer: a conv block followed by a squeeze-and-excite gate.

  The input is added back to the output when ``in_channels`` equals
  ``out_channels``.
  """
  def __init__(self, in_channels : int, out_channels : int, conv : Conv, se : SqueezeAndExcite, device = 'cuda'):
    super().__init__()
    self.in_channels, self.out_channels = in_channels, out_channels
    self.conv = conv
    self.se = se
    self.device = device

  def forward(self, x):
    """Run conv and squeeze-excite, with the optional residual connection."""
    shortcut = x.clone()
    out = self.se(self.conv(x))
    if self.in_channels == self.out_channels:
      out += shortcut # optional residual connection
    return out


In [None]:
def get_fixed_MBblock(ElasticMBblock : ElasticMBblock)->MBblock:
  """Convert an elastic layer into a fixed ``MBblock`` using its active weights."""
  fixed_conv = get_fixed_Conv(ElasticMBblock.dep_conv)
  fixed_se = get_fixed_SqueezeAndExcite(ElasticMBblock.se)
  return MBblock(
      ElasticMBblock.in_channels,
      ElasticMBblock.out_channels,
      fixed_conv,
      fixed_se,
      ElasticMBblock.device,
  )

In [None]:
class ElasticUnit(nn.Module):
  """A stack of up to ``max_depth`` elastic layers.

  Each layer takes the previous layer's ``out_channels`` as its input width.
  ``active_layers`` starts as the full stack and is what ``forward`` iterates.
  """
  def __init__(self,
               in_channels : int,
               max_width : int = 6,
               max_kernel_size : int = 7,
               max_depth : int = 4,
               device = 'cuda'):
    super().__init__()
    self.in_channels = in_channels
    self.max_width = max_width
    self.max_kernel_size = max_kernel_size
    self.max_depth = max_depth
    self.device = device
    self.out_channels = 0  # set by _set_layers
    self.layers = nn.ModuleList()
    self._set_layers()
    self.active_layers = self.layers

  def _set_layers(self):
    """Create the layers, chaining channel counts, and record the final width."""
    channels = self.in_channels
    for _ in range(self.max_depth):
      block = ElasticMBblock(channels, self.max_width, self.max_kernel_size, self.device)
      self.layers.append(block)
      channels = block.out_channels
    self.out_channels = channels

  def forward(self, x):
    """Pass ``x`` through the active layers in order."""
    out = x
    for block in self.active_layers:
      out = block(out)
    return out

In [None]:
class Unit(nn.Module):
  """Fixed counterpart of an elastic unit: applies ``layers`` one after another."""
  def __init__(self, in_channels : int, out_channels : int, layers : nn.ModuleList):
    super().__init__()
    self.in_channels, self.out_channels = in_channels, out_channels
    self.layers = layers

  def forward(self, x):
    """Pass ``x`` through every layer in order."""
    out = x
    for block in self.layers:
      out = block(out)
    return out

In [None]:
def get_fixed_Unit(ElasticUnit : ElasticUnit)->Unit:
  """Convert the active layers of an elastic unit into a fixed ``Unit``."""
  fixed_blocks = [get_fixed_MBblock(block) for block in ElasticUnit.active_layers]
  layers = nn.ModuleList()
  layers.extend(fixed_blocks)
  return Unit(ElasticUnit.in_channels, ElasticUnit.out_channels, layers)

In [None]:
class OFAnet(nn.Module):
  """Elastic supernet: stem -> elastic units -> final elastic layer -> pool -> classifier.

  Args:
    num_units: number of ``ElasticUnit`` stages.
    in_channels: channels of the input images.
    num_classes: size of the classifier output.
    max_depth: layers per unit.
    max_width: channel expansion factor of every layer.
    max_kernel_size: largest depthwise kernel size.
    device: device on which the elastic kernels are created.
    init_channels: channels produced by the stem convolution.
    fixed_reduction: if True, the first layer of every unit uses stride 2.
  """
  def __init__(self,
               num_units : int = 5,
               in_channels : int = 3,
               num_classes : int = 1000,
               max_depth : int = 4,
               max_width : int = 6,
               max_kernel_size : int = 7,
               device = 'cuda',
               init_channels : int = 8,
               fixed_reduction : bool = False
               ):
    super().__init__()
    self.num_units = num_units
    self.in_channels = in_channels
    self.num_classes = num_classes
    self.max_depth = max_depth
    self.max_width = max_width
    self.max_kernel_size = max_kernel_size
    self.device = device
    self.init_channels = init_channels
    # Running channel count; advanced by _set_units to the last unit's width.
    self.active_channels = self.init_channels
    self.fixed_reduction = fixed_reduction
    # The stem and final blocks receive the requested device as well;
    # otherwise they silently fall back to ElasticMBblock's 'cuda' default.
    self.stem = nn.Sequential(nn.Conv2d(self.in_channels, self.init_channels, kernel_size=3, stride=2, padding=1),
                              ElasticMBblock(in_channels = self.init_channels, max_width = 1, device = self.device)
                              )

    self.units = nn.ModuleList()

    self._set_units()
    self.final_layer = ElasticMBblock(in_channels = self.active_channels,
                                      max_width = self.max_width,
                                      max_kernel_size = self.max_kernel_size,
                                      device = self.device)
    self.global_pool = nn.AdaptiveAvgPool2d(1)
    # The final layer widens active_channels by max_width.
    self.classifier = nn.Linear(self.max_width * self.active_channels, num_classes)

  def _set_units(self):
    """Create the units, chaining each unit's output width into the next."""
    for i in range(self.num_units):
      unit = ElasticUnit(self.active_channels, self.max_width, self.max_kernel_size, self.max_depth, self.device)
      if self.fixed_reduction:
        unit.layers[0].dep_conv.stride = 2
      self.units.append(unit)
      self.active_channels = unit.out_channels

  def _process_layer_for_shrinking(self, layer, active_channels, config):
    """Shrink ``layer`` to ``config`` ("k", "w") for the given input width; return its output width."""
    elastic_kernel_size = config["k"]
    elastic_width = config["w"]
    layer.in_channels = active_channels
    layer.out_channels = layer.in_channels * elastic_width
    layer.shrink(elastic_kernel_size, elastic_width)
    return layer.out_channels

  def forward(self, x):
    """Return class logits for a batch of images."""
    x = self.stem(x)
    for unit in self.units:
      x = unit(x)
    x = self.final_layer(x)
    x = self.global_pool(x)
    x = x.view(x.size(0), -1)  # (batch, channels, 1, 1) -> (batch, channels)
    x = self.classifier(x)
    return x

In [None]:
class Subnet(nn.Module):
  """Fixed network assembled from already-built parts.

  ``in_channels`` is read from the first module of ``stem``.
  """
  def __init__(self, stem, units, final_layer, global_pool, classifier):
    super().__init__()
    first_stem_module = stem[0]
    self.in_channels = first_stem_module.in_channels
    self.stem = stem
    self.units = units
    self.final_layer = final_layer
    self.global_pool = global_pool
    self.classifier = classifier

  def forward(self, x):
    """Return classifier outputs for a batch of inputs."""
    features = self.stem(x)
    for stage in self.units:
      features = stage(features)
    pooled = self.global_pool(self.final_layer(features))
    flat = pooled.view(pooled.size(0), -1)
    return self.classifier(flat)

In [None]:
def get_fixed_Subnet(OFAnet: OFAnet) -> Subnet:
    """Assemble a ``Subnet`` from the supernet's currently active configuration.

    The stem convolution, global pool and classifier are the supernet's own
    modules; the elastic blocks are converted to fixed blocks.
    """
    first_conv, stem_block = OFAnet.stem[0], OFAnet.stem[1]
    fixed_stem = nn.Sequential(first_conv, get_fixed_MBblock(stem_block))

    fixed_units = nn.ModuleList()
    for elastic_unit in OFAnet.units:
        fixed_units.append(get_fixed_Unit(elastic_unit))

    fixed_final_layer = get_fixed_MBblock(OFAnet.final_layer)

    return Subnet(
        stem=fixed_stem,
        units=fixed_units,
        final_layer=fixed_final_layer,
        global_pool=OFAnet.global_pool,
        classifier=OFAnet.classifier,
    )

In [None]:
def ProgressiveShrinking(supernet: OFAnet, configs: list[list[dict]]) -> Subnet:
  """Shrink ``supernet`` in place to ``configs`` and return the fixed subnet.

  ``configs[u]`` holds one ``{"k": kernel_size, "w": width}`` dict per active
  layer of unit ``u``; its length is the unit's depth.
  """
  # Depth: keep the first D layers of every unit.
  for idx, unit in enumerate(supernet.units):
    depth = len(configs[idx])
    unit.active_layers = unit.layers[:depth]

  # Kernel size (transformation matrix) and width (L1 norm of kernels),
  # propagating the channel count from layer to layer.
  channels = supernet.init_channels
  for idx, unit in enumerate(supernet.units):
    unit.in_channels = channels
    for pos, layer in enumerate(unit.active_layers):
      spec = configs[idx][pos]
      kernel_size = spec["k"]
      width = spec["w"]
      unchanged = (layer.active_kernel_size == kernel_size and
                   layer.active_width == width and
                   layer.in_channels == channels)
      if not unchanged:
        layer.in_channels = channels
        layer.shrink(kernel_size, width)
      channels = layer.out_channels
    unit.out_channels = channels

  # Final layer: only its first pointwise kernels are re-selected for the new input width.
  final = supernet.final_layer
  final.in_channels = channels
  p1_copy = final.dep_conv.p1_kernels.clone()
  final.dep_conv.active_p1_kernels = final._process_kernels(p1_copy, channels)

  # Export the fixed subnet.
  return get_fixed_Subnet(supernet)

In [None]:
def get_simple_configs(num_units, k, d, w):
  """Return a uniform config: ``num_units`` units of ``d`` layers, each ``{"k": k, "w": w}``."""
  configs = []
  for _ in range(num_units):
    unit_config = []
    for _ in range(d):
      unit_config.append({"k": k, "w": w})
    configs.append(unit_config)
  return configs

In [None]:
def iterate_config_space(num_units, config_space : list[list[int]]):
  """Return one uniform config per (k, d, w) combination.

  ``config_space`` is ``[kernel_sizes, depths, widths]``; combinations are
  produced with k varying slowest and w fastest.
  """
  kernel_sizes = config_space[0]
  return [get_simple_configs(num_units, k, d, w)
          for k in kernel_sizes
          for d in config_space[1]
          for w in config_space[2]]

In [None]:
def sample_config(num_units, config_space : list[list[int]], fixed_dims : "list | None" = None):
  """Sample a random network config from ``config_space``.

  Args:
    num_units: number of units in the network.
    config_space: ``[kernel_sizes, depths, widths]``.
    fixed_dims: names among ``"k"``, ``"d"``, ``"w"`` that are not sampled;
      a fixed dimension takes the last entry of its list. ``None`` means
      nothing is fixed.

  Returns:
    One list per unit, each holding one ``{"k": ..., "w": ...}`` dict per
    layer. A single depth is drawn and shared by all units; kernel size and
    width are drawn per layer. Sampled values are plain Python scalars.
  """
  fixed = () if fixed_dims is None else fixed_dims

  def _pick(choices, name):
    # Fixed dimensions take the last listed value.
    if name in fixed:
      return choices[-1]
    # .item() turns the numpy scalar into a native Python value.
    return np.random.choice(choices).item()

  d = _pick(config_space[1], "d")
  config = []
  for _ in range(num_units):
    unit_config = []
    for _ in range(d):
      k = _pick(config_space[0], "k")
      w = _pick(config_space[2], "w")
      unit_config.append({"k": k, "w": w})
    config.append(unit_config)
  return config

In [None]:
def train_OFAnet(supernet, config_space, epochs, dataloader, criterion, optimizers, schedulers, device):
  """Run three training stages, sampling a new sub-network every epoch.

  Stage 1 samples only the kernel size (depth and width fixed); after each
  stage the last entry of the fixed list is dropped, so stage 2 also samples
  the width and stage 3 samples everything.

  Args:
    supernet: the elastic network that is shrunk every epoch.
    config_space: ``[kernel_sizes, depths, widths]`` passed to ``sample_config``.
    epochs: epochs per stage.
    dataloader: training batches.
    criterion: loss function.
    optimizers: one optimizer per stage (indexed 0..2).
    schedulers: one scheduler per stage (indexed 0..2), stepped once per epoch.
    device: device the sampled subnet and the batches are moved to.
  """
  fixed_dims = ["d", "w"]
  supernet.train()
  for stage in range(3):
    optimizer = optimizers[stage]
    scheduler = schedulers[stage]
    for epoch in range(1, epochs + 1):
      config = sample_config(supernet.num_units, config_space, fixed_dims)
      subnet = ProgressiveShrinking(supernet, config)
      subnet.to(device)
      # Use the loader passed in, not a notebook-global one.
      loss = train_one_epoch(subnet, dataloader, criterion, optimizer, device)
      scheduler.step()
      if epoch % 5 == 0:
        print(f"Stage {stage + 1}, Epoch [{epoch}/{epochs}], Loss: {loss:.4f}")
    fixed_dims = fixed_dims[:-1]

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
  """Train ``model`` for one pass over ``dataloader``.

  Returns:
    The sum of per-batch losses weighted by batch size, divided by
    ``len(dataloader.dataset)``.
  """
  model.train()
  loss_sum = 0.0
  for batch, labels in dataloader:
    batch = batch.to(device)
    labels = labels.to(device)
    optimizer.zero_grad()
    batch_loss = criterion(model(batch), labels)
    batch_loss.backward()
    optimizer.step()
    # Weight by batch size so a smaller last batch does not skew the mean.
    loss_sum += batch_loss.item() * batch.size(0)
  num_samples = len(dataloader.dataset)
  return loss_sum / num_samples

In [None]:
def evaluate(model, dataloader, device):
  """Return the fraction of samples whose highest-scoring output matches the target."""
  model.eval()
  hits, seen = 0, 0
  with torch.no_grad():
    for batch, labels in dataloader:
      batch = batch.to(device)
      labels = labels.to(device)
      predicted = torch.max(model(batch), 1)[1]  # index of the largest output per sample
      hits += (predicted == labels).sum().item()
      seen += labels.size(0)
  return hits / seen

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Training-time preprocessing: random horizontal flip, conversion to tensor,
# then per-channel normalisation with the given mean / std
# (presumably the CIFAR-10 channel statistics — TODO confirm).
transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                     (0.2023, 0.1994, 0.2010))
    ])

In [None]:
# Test-time preprocessing: no augmentation, only tensor conversion and the
# same per-channel normalisation as the training transform.
transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                     (0.2023, 0.1994, 0.2010))
    ])

In [None]:
# CIFAR-10 train / test splits, downloaded into ./data on first use.
train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)
# Only the training loader shuffles; both use batches of 128 and 4 worker processes.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

100%|██████████| 170M/170M [00:13<00:00, 13.1MB/s]


In [None]:
# Build the supernet on the device selected above. Without an explicit
# ``device`` OFAnet falls back to its 'cuda' default, which defeats the CPU
# fallback chosen when no GPU is available.
supernet = OFAnet(num_units=1,
                  in_channels=3,
                  num_classes=10,
                  max_depth=3,
                  max_width=4,
                  max_kernel_size=7,
                  device=device,
                  fixed_reduction = True)
supernet.to(device)

OFAnet(
  (stem): Sequential(
    (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ElasticMBblock(
      (dep_conv): ElasticConv(
        (ReLU6): ReLU6()
      )
      (se): ElasticSqueezeAndExcite()
    )
  )
  (units): ModuleList(
    (0): ElasticUnit(
      (layers): ModuleList(
        (0-2): 3 x ElasticMBblock(
          (dep_conv): ElasticConv(
            (ReLU6): ReLU6()
          )
          (se): ElasticSqueezeAndExcite()
        )
      )
      (active_layers): ModuleList(
        (0-2): 3 x ElasticMBblock(
          (dep_conv): ElasticConv(
            (ReLU6): ReLU6()
          )
          (se): ElasticSqueezeAndExcite()
        )
      )
    )
  )
  (final_layer): ElasticMBblock(
    (dep_conv): ElasticConv(
      (ReLU6): ReLU6()
    )
    (se): ElasticSqueezeAndExcite()
  )
  (global_pool): AdaptiveAvgPool2d(output_size=1)
  (classifier): Linear(in_features=2048, out_features=10, bias=True)
)

In [None]:
# Search space, indexed as [kernel sizes, depths, widths]. The last entry of
# each list is the value sample_config uses when that dimension is fixed.
config_space = [[5, 7], # k: depthwise kernel sizes
                [2, 3], # d: number of active layers per unit
                [2, 4]  # w: width (channel expansion) multipliers
                 ]

In [None]:
# Now let's train the OFA network
criterion = nn.CrossEntropyLoss()
epochs = 180

optimizer = torch.optim.SGD(supernet.parameters(), lr=2.6, momentum=0.9, weight_decay=3e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
for epoch in range(1, epochs + 1):
  loss = train_one_epoch(supernet, train_loader, criterion, optimizer, device)
  scheduler.step()
  if epoch % 10 == 0:
    print(f"Epoch [{epoch}/{epochs}], Loss: {loss:.4f}")
print("done training supernet")
acc = evaluate(supernet, test_loader, device)
print("supernet accuracy:",acc)

Epoch [10/180], Loss: 24300.6246
Epoch [20/180], Loss: 27081.0991
Epoch [30/180], Loss: 25827.4613
Epoch [40/180], Loss: 25561.2965
Epoch [50/180], Loss: 23562.3627
Epoch [60/180], Loss: 20719.7780
Epoch [70/180], Loss: 17261.6188
Epoch [80/180], Loss: 14938.7788
Epoch [90/180], Loss: 12885.2181
Epoch [100/180], Loss: 9825.3774
Epoch [110/180], Loss: 8000.0122
Epoch [120/180], Loss: 6313.3398
Epoch [130/180], Loss: 4094.1782
Epoch [140/180], Loss: 2416.6540
Epoch [150/180], Loss: 1239.7816
Epoch [160/180], Loss: 456.5941
Epoch [170/180], Loss: 76.6134
Epoch [180/180], Loss: 54.8479
done training supernet
supernet accuracy: 0.2149


In [None]:
# Three-stage training with sampled sub-networks: one optimizer and one cosine
# scheduler per stage of train_OFAnet.
ft_epochs = 25
# Stage 1 updates only parameters whose name contains "matrix" (the kernel
# transformation matrices); stages 2 and 3 update every supernet parameter.
optimizers = [torch.optim.SGD([param for name, param in supernet.named_parameters() if "matrix" in name], lr=0.96, momentum=0.9, weight_decay=3e-5),
              torch.optim.SGD(supernet.parameters(), lr=0.08, momentum=0.9, weight_decay=3e-5),
              torch.optim.SGD(supernet.parameters(), lr=0.08, momentum=0.9, weight_decay=3e-5)]
schedulers = [torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ft_epochs) for optimizer in optimizers]

train_OFAnet(supernet, config_space, ft_epochs, train_loader, criterion, optimizers, schedulers, device)

Stage 1, Epoch [5/25], Loss: 253.8057
Stage 1, Epoch [10/25], Loss: 468.5583
Stage 1, Epoch [15/25], Loss: 468.4309
Stage 1, Epoch [20/25], Loss: 457.4867
Stage 1, Epoch [25/25], Loss: 285.8737
Stage 2, Epoch [5/25], Loss: 647.3961
Stage 2, Epoch [10/25], Loss: 649.8146
Stage 2, Epoch [15/25], Loss: 266.9540
Stage 2, Epoch [20/25], Loss: 59.0908
Stage 2, Epoch [25/25], Loss: 66.2991
Stage 3, Epoch [5/25], Loss: 939.9427
Stage 3, Epoch [10/25], Loss: 513.8098
Stage 3, Epoch [15/25], Loss: 311.1151
Stage 3, Epoch [20/25], Loss: 54.9380
Stage 3, Epoch [25/25], Loss: 69.2655


In [None]:
# Persist the supernet's state_dict so it can be reloaded without retraining.
torch.save(supernet.state_dict(), 'supernet.pt')

In [None]:
# Disabled cell, kept as a string literal: rebuilds an identically configured
# OFAnet and loads the weights saved to 'supernet.pt'.
'''
supernet = OFAnet(num_units=1,
                  in_channels=3,
                  num_classes=10,
                  max_depth=3,
                  max_width=4,
                  max_kernel_size=7,
                  fixed_reduction = True)
supernet.load_state_dict(torch.load('supernet.pt'))
'''

"\nsupernet = OFAnet(num_units=1,\n                  in_channels=3,\n                  num_classes=10,\n                  max_depth=3,\n                  max_width=4,\n                  max_kernel_size=7,\n                  fixed_reduction = True)\nsupernet.load_state_dict(torch.load('supernet.pt'))\n"

In [None]:
# For every uniform (k, d, w) configuration: extract the subnet, report its
# cost, then measure test accuracy before and after a short fine-tuning run.
fine_tune_epochs = 25
configs = iterate_config_space(supernet.num_units, config_space)
input_shape = (1, 3, 32, 32)
for config in configs:
  print("(k, d, w):", (config[0][0]["k"], len(config[0]), config[0][0]["w"]))
  subnet = ProgressiveShrinking(supernet, config)
  subnet.to(device)
  # Measure the extracted subnet: the figures are printed as "Subnet ...", and
  # profiling the supernet would report the same parameter count for every config.
  flops, macs, params = calculate_flops(model=subnet,
                                      input_shape=input_shape,
                                      output_as_string=True,
                                      print_detailed=False,
                                      output_precision=4)
  print("Subnet FLOPs:%s   MACs:%s   Params:%s \n" %(flops, macs, params))
  criterion = nn.CrossEntropyLoss()
  acc = evaluate(subnet, test_loader, device)
  print("acc before fine-tuning:", acc)
  ft_optimizer = torch.optim.SGD(subnet.parameters(), lr=2e-5, momentum=0.9, weight_decay=3e-5)
  ft_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(ft_optimizer, T_max=fine_tune_epochs)
  for epoch in range(1, fine_tune_epochs + 1):
    loss = train_one_epoch(subnet, train_loader, criterion, ft_optimizer, device)
    ft_scheduler.step()
  acc = evaluate(subnet, test_loader, device)
  print("acc after fine-tuning:", acc)


(k, d, w): (5, 2, 2)

------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  27.96 K 
fwd MACs:                                                               281.759 MMACs
fwd FLOPs:                                                              563.791 MFLOPS
fwd+bwd MACs:                                                           845.277 MMACs
fwd+bwd FLOPs:                                                          1.6914 GFLOPS
----------------------------------------------------------------------------------------