## Customize Pre-trained Models

This example shows how to **customize your own pre-trained model** to try out new ideas. You can tailor *any* **add-in** module and integrate it into a large pre-trained model **with only a few lines of code**.

![tutorials_overview](./assests/tutorials_overview.png)

### Introduce the Custom Model

Let's begin with a three-layer Multilayer Perceptron (MLP).

An MLP is not a strong image learner, but it is simple enough to get started quickly. The same designs and modifications carry over to other custom networks.

+ Run the code block below to customize the model:

In [1]:
import torch.nn as nn
class MLP(nn.Module):
    """
    MLP Class
    ==============

    Multilayer Perceptron (MLP) model for image (224x224) classification tasks.
    
    Args:
        args (object): Custom arguments or configurations.
        num_classes (int): Number of output classes.
    """
    def __init__(self, args, num_classes):
        super(MLP, self).__init__()
        self.args = args
        self.image_size = 224
        self.fc1 = nn.Linear(self.image_size * self.image_size * 3, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """
        Forward pass of the model.
        
        Args:
            x (torch.Tensor): Input tensor.
        
        Returns:
            torch.Tensor: Output logits.
        """
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.fc2(x)
        x = nn.ReLU()(x)
        x = self.fc3(x)
        return x

![Custom Multilayer Perceptron (MLP) Architecture](./assests/tutorials_mlp.png)
*Figure 1: Custom Multilayer Perceptron (MLP) Architecture*
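
As a quick sanity check on the shapes involved, the sketch below pushes a random batch through a stand-in network with the same layer sizes as the `MLP` above. The `nn.Sequential` stand-in and the choice of 100 classes are assumptions made for illustration; they are not part of the tutorial code.

```python
import torch
import torch.nn as nn

num_classes = 100  # assumption: a 100-class task such as CIFAR-100

# Stand-in with the same layer sizes as the MLP above:
# flatten the image, then three linear layers with ReLU in between.
mlp = nn.Sequential(
    nn.Flatten(),
    nn.Linear(224 * 224 * 3, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, num_classes),
)

batch = torch.randn(2, 3, 224, 224)  # two random RGB images of size 224x224
with torch.no_grad():
    logits = mlp(batch)

print(logits.shape)  # torch.Size([2, 100])
```

Each image is flattened into a vector of 224 × 224 × 3 = 150528 values, which is why `fc1` is by far the largest layer.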

Now, expand models from **fleeting moments of inspiration**.

We will customize and modify the network structure through a few lines of code from **ZhiJian**.

The additional auxiliary structures are also implemented based on the PyTorch framework. The auxiliary structures inherit the base class `AddinBase`, which integrates some basic methods for data access.

### Design Additional Add-in Modules

+ Run the code block below to customize add-in modules and entry points for the model.

In [2]:
from zhijian.models.addin.module.base import AddinBase
class MLPAddin(AddinBase):
    """
    MLPAddin Class
    ==============

    Multilayer Perceptron (MLP) add-in.

    Args:
        config (object): Custom configuration or arguments.
        model_config (object): Configuration specific to the model.
    """
    def __init__(self, config, model_config):
        super(MLPAddin, self).__init__()

        self.config = config
        self.embed_dim = model_config.hidden_size

        self.reduction_dim = 16

        self.fc1 = nn.Linear(self.embed_dim, self.reduction_dim)
        if getattr(config, 'mlp_addin_output_size', None) is not None:
            self.fc2 = nn.Linear(self.reduction_dim, config.mlp_addin_output_size)
        else:
            self.fc2 = nn.Linear(self.reduction_dim, self.embed_dim)

    def forward(self, x):
        """
        Forward pass of the MLP add-in.

        Args:
            x (tensor): Input tensor.

        Returns:
            tensor: Output tensor after passing through the MLP add-in.
        """
        identity = x 
        out = self.fc1(identity)
        out = nn.ReLU()(out)
        out = self.fc2(out)

        return out

    def adapt_input(self, module, inputs):
        """
        Hook function to adapt the input data before it enters the module.

        Args:
            module (nn.Module): The module being hooked.
            inputs (tuple): (Inputs before the module,).

        Returns:
            tensor: Adapted input tensor after passing through the MLP add-in.
        """
        x = inputs[0]
        return self.forward(x)

    def adapt_output(self, module, inputs, outputs):
        """
        Hook function to adapt the output data after it leaves the module.

        Args:
            module (nn.Module): The module being hooked.
            inputs (tuple): (Inputs before the module,).
            outputs (tensor): Outputs after the module.

        Returns:
            tensor: Adapted output tensor after passing through the MLP add-in.
        """
        return self.forward(outputs)
    
    def adapt_across_input(self, module, inputs):
        """
        Hook function to adapt the data across the modules.

        Args:
            module (nn.Module): The module being hooked.
            inputs (tuple): (Inputs before the module,).

        Returns:
            tensor: Adapted input tensor after adding the MLP add-in output to the subsequent module.
        """
        x = inputs[0]
        x = x + self.forward(self.inputs_cache)
        return x

    def adapt_across_output(self, module, inputs, outputs):
        """
        Hook function to adapt the data across the modules.

        Args:
            module (nn.Module): The module being hooked.
            inputs (tuple): (Inputs before the module,).
            outputs (tensor): Outputs after the module.

        Returns:
            tensor: Adapted output tensor after adding the MLP add-in output to the previous module's output.
        """
        outputs = outputs + self.forward(self.inputs_cache)
        return outputs

In the auxiliary structure `MLPAddin` above, we add a low-rank bottleneck (two linear layers with a reduced dimension in the middle), inspired by parameter-efficient methods such as *Adapter* and *LoRA*. It is defined in the `__init__` function and applied in the `forward` function.

![Additional Auxiliary Structure Example](./assests/tutorials_addin_structure.png)
*Figure 2: Additional Auxiliary Structure Example*
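
To see why the bottleneck is cheap, the following arithmetic sketch compares its parameter count with that of a single full-width linear layer. It assumes an embedding size of 256 (the `hidden_size` used later in this tutorial) and the reduction size of 16 from `MLPAddin`; `linear_params` is a hypothetical helper introduced only for this example.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Number of weights plus biases in a fully connected layer."""
    return in_features * out_features + out_features

embed_dim, reduction_dim = 256, 16  # sizes used in this tutorial

# Bottleneck: embed_dim -> reduction_dim -> embed_dim
bottleneck = linear_params(embed_dim, reduction_dim) + linear_params(reduction_dim, embed_dim)
# Reference: one full-width layer, embed_dim -> embed_dim
full_layer = linear_params(embed_dim, embed_dim)

print(bottleneck)                        # 8464
print(full_layer)                        # 65792
print(f"{bottleneck / full_layer:.1%}")  # 12.9%
```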

As shown above, the `hook` methods starting with `adapt_` are our entry functions. They serve as hooks to attach the extended modules to the base model. We will further explain their roles in the following text.

### Deploy the Inter-layer Insertion & Cross-layer Concatenation Points

We aim to customize our model by **inter-layer insertion** and **cross-layer concatenation** of the auxiliary structures at different positions within the base model (such as the custom MLP mentioned earlier). When configuring the insertion or concatenation positions, **ZhiJian** provides **a minimalistic one-line configuration syntax**.

The syntax for configuring an add-in module into the base model is shown below. We start with a few examples and then explain each part of the configuration.

+ *Inter-layer Insertion*:

  ```python
  >>> (MLPAddin.adapt_input): ...->{inout1}(fc2)->...
  ```

  ![Additional Add-in Structure - Inter-layer Insertion 1](./assests/tutorials_mlp_addin_1.png)
  *Figure 3: Additional Add-in Structure - Inter-layer Insertion 1*


  ```python
  >>> (MLPAddin.adapt_output): ...->(fc1){inout1}->...
  ```


  ![Additional Add-in Structure - Inter-layer Insertion 2](./assests/tutorials_mlp_addin_2.png)
  *Figure 4: Additional Add-in Structure - Inter-layer Insertion 2*


+ *Cross-layer Concatenation*:

  ```python
  >>> (MLPAddin.adapt_across_input): ...->(fc1){in1}->...->{out1}(fc3)->...
  ```

  ![Additional Add-in Structure - Cross-layer Concatenation](./assests/tutorials_mlp_addin_3.png)
  *Figure 5: Additional Add-in Structure - Cross-layer Concatenation*


#### Base Module: `->(fc1)`

Consider a base model implemented with the PyTorch framework, where the representation of each layer and module in the model is straightforward:


+ The `print` function outputs the names defined in the model structure:

  ```python
  >>> print(model)
  ```

+ The structure of some classic backbones can be represented as follows:


  + MLP:

    ```python
    >>> input->(fc1)->(fc2)->(fc3)->output
    ```

  + ViT `block[i]`:
  
    ```python
    >>> input->...->(block[i].norm1)->
            (block[i].attn.qkv)->(block[i].attn.attn_drop)->(block[i].attn.proj)->(block[i].attn.proj_drop)->
            (block[i].ls1)->(block[i].drop_path1)->
                (block[i].norm2)->
                (block[i].mlp.fc1)->(block[i].mlp.act)->(block[i].mlp.drop1)->(block[i].mlp.fc2)->(block[i].mlp.drop2)->
                    (block[i].ls2)->(block[i].drop_path2)->...->output
    ```
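
The names in parentheses are the attribute names that PyTorch assigns to submodules. As a minimal sketch, the snippet below lists them with `named_modules()` for a small stand-in model (`TinyMLP` is hypothetical and uses tiny layer sizes) and prints them in the arrow notation used above.

```python
import torch.nn as nn

class TinyMLP(nn.Module):
    """Hypothetical stand-in with the same layer names as the tutorial's MLP."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, 4)
        self.fc3 = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))

model = TinyMLP()
# named_modules() yields (qualified_name, module) pairs; the root module has an empty name.
names = [name for name, _ in model.named_modules() if name]
print("input->" + "->".join(f"({n})" for n in names) + "->output")
# input->(fc1)->(fc2)->(fc3)->output
```

Note that `named_modules()` follows registration order, which matches execution order here but is not guaranteed to do so for every model.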

#### Default Module: `...`

In the configuration syntax of **ZhiJian**, `...` stands for the layers or modules that keep their default behavior and do not need to be written out.

+ For example, when we only focus on the `(fc2)` module in MLP and the `(block[i].mlp.fc2)` module in ViT:

  + MLP:

    ```python
    >>> ...->(fc2)->...
    ```
  + ViT:
  
    ```python
    >>> ...->(block[i].mlp.fc2)->...
    ```

#### Insertion & Concatenation Function: `():`

For the custom auxiliary structure `MLPAddin` above, the functions starting with `adapt_` serve as the processing centers that are **inserted** into or **concatenated** with the base model.


+ There are primarily two types of parameter passing methods:

  ```python
  def adapt_input(self, module, inputs):
      """
      Args:
          module (nn.Module): The module being hooked.
          inputs (tuple): (Inputs before the module,).
      """
      ...
  
  def adapt_output(self, module, inputs, outputs):
      """
      Args:
          module (nn.Module): The module being hooked.
          inputs (tuple): (Inputs before the module,).
          outputs (tensor): Outputs after the module.
      """
      ...
  ```

  where

  + `adapt_input(self, module, inputs)` is generally placed before the module. It is called before the data enters the module, so it can intercept and process the module's `inputs`.

  + `adapt_output(self, module, inputs, outputs)` is generally placed after the module. It is called after the data leaves the module, so it can intercept and process the module's `outputs`.

These functions are "hooked" into the base model when the module is configured, and they serve as the key connectors between the base model and the auxiliary structure.
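
The two signatures above match those of PyTorch's own forward pre-hooks and forward hooks. The sketch below uses only plain PyTorch (`register_forward_pre_hook` and `register_forward_hook`) to show what each kind of hook can do; whether **ZhiJian** registers its `adapt_` functions exactly this way is an assumption, and the hook functions `double_input` and `zero_output` are hypothetical.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

def double_input(module, inputs):
    # Pre-hook: `inputs` is a tuple of positional arguments. The returned
    # tuple replaces them before `module.forward` runs.
    return (inputs[0] * 2,)

def zero_output(module, inputs, outputs):
    # Forward hook: the returned value replaces the module's output.
    return torch.zeros_like(outputs)

x = torch.ones(1, 4)
plain = layer(x)

pre_handle = layer.register_forward_pre_hook(double_input)
with_pre = layer(x)       # the layer actually sees 2 * x
pre_handle.remove()

post_handle = layer.register_forward_hook(zero_output)
with_post = layer(x)      # the layer's output is replaced by zeros
post_handle.remove()

print(torch.allclose(with_pre, layer(2 * x)))  # True
print(with_post.abs().sum().item())            # 0.0
```

Removing the handles restores the original behavior, so the base model itself is never modified.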

#### Insertion & Concatenation Point: `{}`

For an independent auxiliary structure (such as the `MLPAddin` above), its **insertion or concatenation points** with the base network must consist of a *"Data Input"* and a *"Data Output"*, where:

+ **"Data Input"** refers to the network features input into the extended auxiliary structure.
+ **"Data Output"** refers to the adapted features output from the auxiliary structure back to the base network.


Next, let's use some configuration examples of MLP to illustrate the syntax and functionality of **ZhiJian** for **module integration**:

#### Inter-layer Insertion: `inout`

+ As shown in Figure 3, the configuration expression is:

  ```python
  >>> (MLPAddin.adapt_input): ...->{inout1}(fc2)->...
  ```

  where

  + `{inout1}` marks the position where the base model features (the input or output of any layer or module) are taken out.
  
    It denotes both the *"Data Input"* and the *"Data Output"*. The configuration can be `{inoutx}`, where `x` represents the x<sup>th</sup> integration point. For example, `{inout1}` represents the first integration point.

  + In the example above, this inter-layer insertion configuration *intercepts* the features entering the `fc2` module, *passes* them through the auxiliary structure, and then *returns* the result to the `fc2` module. The original features no longer enter `fc2` directly.

#### Cross-layer Concatenation: `in`, `out`

+ As shown in Figure 5, the configuration expression is:

  ```python
  >>> (MLPAddin.adapt_across_input): ...->(fc1){in1}->...->{out1}(fc3)->...
  ```

  where

  + `{in1}`: represents the integration point where the base network features (the input or output of any layer or module) *enter* the add-in structure.
  
    It denotes the *"Data Input"*. The configuration can be `{inx}`, where `x` represents the x<sup>th</sup> integration point. For example, `{in1}` represents the first integration point.

  + `{out1}`: represents the integration point where the features processed by the add-in structure are *returned* to the base network.

    It denotes the *"Data Output"*. The configuration can be `{outx}`, where `x` represents the x<sup>th</sup> integration point. For example, `{out1}` represents the first integration point.
    
  + This cross-layer concatenation configuration *extracts* the features of the `fc1` module's output, *passes them into* the auxiliary structure, and then *returns* them to the base network before the `fc3` module in the form of residual addition.
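
The cross-layer behavior described above can be sketched in plain PyTorch with one hook that caches features and another that adds the add-in result back as a residual. This is an illustration under stated assumptions, not ZhiJian's implementation: the layer sizes are tiny, the `cache` dictionary stands in for `inputs_cache`, and the name `get_post` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
fc1, fc2, fc3 = nn.Linear(8, 4), nn.Linear(4, 4), nn.Linear(4, 2)
addin = nn.Sequential(nn.Linear(4, 2), nn.ReLU(), nn.Linear(2, 4))  # small bottleneck

cache = {}

def get_post(module, inputs, outputs):
    # {in1} after fc1: remember fc1's output and leave the base features untouched.
    cache["features"] = outputs

def adapt_across_input(module, inputs):
    # {out1} before fc3: add the add-in result to fc3's input (residual addition).
    return (inputs[0] + addin(cache["features"]),)

fc1.register_forward_hook(get_post)
fc3.register_forward_pre_hook(adapt_across_input)

x = torch.randn(3, 8)
hooked = fc3(fc2(fc1(x)))

# The same computation written out by hand; F.linear bypasses the hooks.
h1 = F.linear(x, fc1.weight, fc1.bias)
h2 = F.linear(h1, fc2.weight, fc2.bias)
expected = F.linear(h2 + addin(h1), fc3.weight, fc3.bias)

print(torch.allclose(hooked, expected))  # True
```

The base path through `fc2` is left intact, and the add-in contributes a parallel branch from the output of `fc1` to the input of `fc3`.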

+ To make the selection easier, let's first create a helper function that guides the user input:

In [3]:
def select_from_input(prompt_for_select, valid_selections):
    selections2print = '\n\t'.join([f'[{idx + 1}] {i}' for idx, i in enumerate(valid_selections)])
    while True:
        selected = input(f"Please input a {prompt_for_select}, type 'help' to show the options: ")

        if selected == 'help':
            print(f"Available {prompt_for_select}(s):\n\t{selections2print}")
        elif selected.isdigit() and int(selected) >= 1 and int(selected) <= len(valid_selections):
            selected = valid_selections[int(selected) - 1]
            break
        elif selected in valid_selections:
            break
        else:
            print("Sorry, this input is not supported.")
            print(f"Available {prompt_for_select}(s):\n\t{selections2print}")

    print(f"Your selection: {selected}")
    return selected

available_example_config_blitzs = {
    'Insert between `fc1` and `fc2` layer (performed before `fc2`)': "(MLPAddin.adapt_input): ...->{inout1}(fc2)->...",
    'Insert between `fc1` and `fc2` layer (performed after `fc1`)': "(MLPAddin.adapt_output): ...->(fc1){inout1}->...",
    'Splice across `fc2` layer (performed before `fc2` and `fc3`)': "(MLPAddin.adapt_across_input): ...->{in1}(fc2)->{out1}(fc3)->...",
    'Splice across `fc2` layer (performed after `fc1` and before `fc3`)': "(MLPAddin.adapt_across_input): ...->(fc1){in1}->...->{out1}(fc3)->...",
    'Splice across `fc2` layer (performed before and after `fc2`)': "(MLPAddin.adapt_across_output): ...->{in1}(fc2){out1}->...",
    'Splice across `fc2` layer (performed after `fc1` and `fc2`)': "(MLPAddin.adapt_across_output): ...->(fc1){in1}->(fc2){out1}->...",
}

config_blitz = available_example_config_blitzs[select_from_input('add-in structure', list(available_example_config_blitzs.keys()))] # user input about model

Available add-in structure(s):
	[1] Insert between `fc1` and `fc2` layer (performed before `fc2`)
	[2] Insert between `fc1` and `fc2` layer (performed after `fc1`)
	[3] Splice across `fc2` layer (performed before `fc2` and `fc3`)
	[4] Splice across `fc2` layer (performed after `fc1` and before `fc3`)
	[5] Splice across `fc2` layer (performed before and after `fc2`)
	[6] Splice across `fc2` layer (performed after `fc1` and `fc2`)
Your selection: Splice across `fc2` layer (performed before and after `fc2`)
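
To make the one-line strings above easier to read, the following sketch breaks a configuration into its hook name and integration points. It is a hypothetical reader written for this tutorial (the function `sketch_parse` and its return format are assumptions), not the parser that ZhiJian uses. It applies the rule described earlier: `{p}(m)` places point `p` before module `m`, and `(m){p}` places it after `m`.

```python
import re

# Tokens of the one-line syntax: {point}, (module), the "..." placeholder, and "->".
TOKEN = re.compile(r"\{\w+\}|\([\w.\[\]]+\)|\.\.\.|->")

def sketch_parse(config_blitz):
    """Illustrative reader for a config string; not ZhiJian's own parser."""
    header, path = config_blitz.split(":", 1)
    addin_name, hook_name = header.strip().strip("()").split(".")
    tokens = TOKEN.findall(path)
    points = []
    for i, tok in enumerate(tokens):
        if not tok.startswith("{"):
            continue
        nxt = tokens[i + 1] if i + 1 < len(tokens) else ""
        prv = tokens[i - 1] if i > 0 else ""
        if nxt.startswith("("):    # {p}(m): the point sits before module m
            points.append((tok[1:-1], "pre", nxt[1:-1]))
        elif prv.startswith("("):  # (m){p}: the point sits after module m
            points.append((tok[1:-1], "post", prv[1:-1]))
    return {"name": addin_name, "hook": hook_name, "points": points}

parsed = sketch_parse("(MLPAddin.adapt_across_output): ...->{in1}(fc2){out1}->...")
print(parsed)
# {'name': 'MLPAddin', 'hook': 'adapt_across_output', 'points': [('in1', 'pre', 'fc2'), ('out1', 'post', 'fc2')]}
```

For the selected configuration, this yields one point before `fc2` and one point after `fc2`.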


In [4]:
available_example_reuse_modules = {
    'timm.vit_base_patch16_224_in21k': {
        'add-ins and linear layer': 'addin,fc3',
        'add-ins and the last layer of feature extractor and the linear layer (Partial-1)': 'addin,fc2,fc3',
    }
}

availables       = available_example_reuse_modules['timm.vit_base_patch16_224_in21k']
reuse_keys_blitz = availables[select_from_input('reuse modules', list(availables.keys()))] # user input about reuse modules

Available reuse modules(s):
	[1] add-ins and linear layer
	[2] add-ins and the last layer of feature extractor and the linear layer (Partial-1)
Your selection: add-ins and linear layer


In [5]:
available_datasets = [
    'VTAB-1k.CIFAR-100', 'VTAB-1k.CLEVR-Count', 'VTAB-1k.CLEVR-Distance', 'VTAB-1k.Caltech101', 'VTAB-1k.DTD',
    'VTAB-1k.Diabetic-Retinopathy', 'VTAB-1k.Dmlab', 'VTAB-1k.EuroSAT', 'VTAB-1k.KITTI', 'VTAB-1k.Oxford-Flowers-102',
    'VTAB-1k.Oxford-IIIT-Pet', 'VTAB-1k.PatchCamelyon', 'VTAB-1k.RESISC45', 'VTAB-1k.SUN397', 'VTAB-1k.SVHN',
    'VTAB-1k.dSprites-Location', 'VTAB-1k.dSprites-Orientation', 'VTAB-1k.smallNORB-Azimuth', 'VTAB-1k.smallNORB-Elevation'
] # dataset options.
dataset = select_from_input('dataset', available_datasets)  # user input about dataset
dataset_dir = input("Please input your dataset directory: ")   # user input about dataset directory
print(f"Your dataset directory: {dataset_dir}")

Available dataset(s):
	[1] VTAB-1k.CIFAR-100
	[2] VTAB-1k.CLEVR-Count
	[3] VTAB-1k.CLEVR-Distance
	[4] VTAB-1k.Caltech101
	[5] VTAB-1k.DTD
	[6] VTAB-1k.Diabetic-Retinopathy
	[7] VTAB-1k.Dmlab
	[8] VTAB-1k.EuroSAT
	[9] VTAB-1k.KITTI
	[10] VTAB-1k.Oxford-Flowers-102
	[11] VTAB-1k.Oxford-IIIT-Pet
	[12] VTAB-1k.PatchCamelyon
	[13] VTAB-1k.RESISC45
	[14] VTAB-1k.SUN397
	[15] VTAB-1k.SVHN
	[16] VTAB-1k.dSprites-Location
	[17] VTAB-1k.dSprites-Orientation
	[18] VTAB-1k.smallNORB-Azimuth
	[19] VTAB-1k.smallNORB-Elevation
Your selection: VTAB-1k.CIFAR-100
Your dataset directory: /data/zhangyk/data/zhijian


+ Next, we will configure the parameters and proceed with model training and testing:

In [21]:
from zhijian.trainers.base import prepare_args
from zhijian.models.utils import pprint, dict2args
training_mode = 'finetune'
args = dict2args({
    'log_url': 'your/log/directory',             # log directory
    'model': 'timm.vit_base_patch16_224_in21k',  # backbone network
    'config_blitz': config_blitz,                # addin blitz configuration
    'dataset': dataset,                          # dataset
    'dataset_dir': dataset_dir,                  # dataset directory
    'training_mode': training_mode,              # training mode
    'reuse_keys_blitz': reuse_keys_blitz,        # reuse keys blitz configuration
    'optimizer': 'adam',                         # optimizer
    'batch_size': 64,                            # batch size
    'num_workers': 8,                            # num workers
    'max_epoch': 5,                              # max epoch
    'eta_min': 0,                                # eta_min of CosineAnnealingLR
    'lr': 1e-3,                                  # learning rate
    'wd': 5e-5,                                  # weight decay
    'gpu': '0',                                  # gpu id
    'seed': 0,                                   # random seed
    'verbose': True,                             # control the verbosity of the output
    'only_do_test': False                        # test flag
})      

args = prepare_args(args, update_default=True)
pprint(vars(args))

{'aa': None,
 'addins': [{'hook': [['get_pre', 'pre'], ['adapt_across_output', 'post']],
             'location': [['fc2'], ['fc2']],
             'name': 'MLPAddin'}],
 'amp': False,
 'amp_dtype': 'float16',
 'amp_impl': 'native',
 'aot_autograd': False,
 'aug_repeats': 0,
 'aug_splits': 0,
 'batch_size': 64,
 'bce_loss': False,
 'bce_target_thresh': None,
 'bn_eps': None,
 'bn_momentum': None,
 'channels_last': False,
 'checkpoint_hist': 10,
 'class_map': '',
 'clip_grad': None,
 'clip_mode': 'norm',
 'color_jitter': 0.4,
 'config_blitz': '(MLPAddin.adapt_across_output): ...->{in1}(fc2){out1}->...',
 'cooldown_epochs': 0,
 'crop_mode': None,
 'crop_pct': None,
 'cutmix': 0.0,
 'cutmix_minmax': None,
 'data': None,
 'data_dir': None,
 'dataset': 'VTAB-1k.CIFAR-100',
 'dataset_dir': '/data/zhangyk/data/zhijian',
 'dataset_download': False,
 'decay_epochs': 90,
 'decay_milestones': [90, 180, 270],
 'decay_rate': 0.1,
 'dist_bn': 'reduce',
 'drop': 0.0,
 'drop_block': None,
 'drop_connec

+ Run the code block below to configure the GPU and the model (excluding additional auxiliary structures):

In [27]:
import torch
import os
assert torch.cuda.is_available()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
torch.cuda.set_device(int(args.gpu))

from zhijian.data.config import DATASET2NUM_CLASSES
from zhijian.models.backbone.base import ModelWrapper
from zhijian.models.configs.base import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
model       = MLP(args, DATASET2NUM_CLASSES[args.dataset.replace('VTAB-1k.','')])
model       = ModelWrapper(model)
model_args  = dict2args({'hidden_size': 256, 'input_size': (224, 224), 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD})

+ Run the code block below to configure additional auxiliary structures:

In [28]:
from zhijian.models.addin.base import prepare_addins
from zhijian.models.backbone.base import prepare_hook, prepare_gradient, prepare_cuda
args.mlp_addin_output_size = 256
addins, fixed_params = prepare_addins(args, model_args, addin_classes=[MLPAddin])

prepare_hook(args.addins, addins, model, 'addin')
prepare_gradient(args.reuse_keys, model)
device = prepare_cuda(model)
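
The effect of the reuse keys can be checked by hand. The arithmetic sketch below assumes that the reuse keys `addin,fc3` leave only the add-in and the final `fc3` layer trainable, and that CIFAR-100 provides 100 classes; `linear_params` is a hypothetical helper. Under these assumptions the result agrees with the parameter summary printed in the training log.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Number of weights plus biases in a fully connected layer."""
    return in_features * out_features + out_features

fc1 = linear_params(224 * 224 * 3, 256)
fc2 = linear_params(256, 256)
fc3 = linear_params(256, 100)  # CIFAR-100 has 100 classes
addin = linear_params(256, 16) + linear_params(16, 256)

total = fc1 + fc2 + fc3 + addin
trainable = addin + fc3        # assumption: reuse keys 'addin,fc3'

print(f"{trainable / 1e6:.2f}M / {total / 1e6:.2f}M ({trainable / total:.5%})")
# 0.03M / 38.64M (0.08843%)
```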

+ Run the code block below to configure the dataset, optimizer, loss function, and other settings:

In [29]:
from zhijian.data.base import prepare_vision_dataloader
import torch.optim as optim

train_loader, val_loader, num_classes = prepare_vision_dataloader(args, model_args)

optimizer = optim.Adam(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.wd
    )
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    args.max_epoch,
    eta_min=args.eta_min
    )
criterion = nn.CrossEntropyLoss()
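
The learning rates produced by cosine annealing follow a closed form, which the sketch below evaluates for the settings used here (`lr=1e-3`, `eta_min=0`, `max_epoch=5`). The function `cosine_annealing_lr` is a hypothetical helper that mirrors the formula documented for PyTorch's `CosineAnnealingLR`; the values it prints correspond to the `LR` column of the training log.

```python
import math

def cosine_annealing_lr(base_lr: float, eta_min: float, t_max: int, epoch: int) -> float:
    """Closed-form cosine annealing value after `epoch` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

schedule = [cosine_annealing_lr(1e-3, 0.0, 5, epoch) for epoch in range(5)]
print([f"{lr:.5g}" for lr in schedule])
# ['0.001', '0.00090451', '0.00065451', '0.00034549', '9.5492e-05']
```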

+ Run the code block below to prepare the `trainer` object and start training and testing:

In [31]:
from zhijian.trainers.base import prepare_trainer
trainer = prepare_trainer(
    args,
    model=model,
    model_args=model_args,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=num_classes,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    criterion=criterion
    )

trainer.fit()
trainer.test()

Log level set to: INFO
Log files are recorded in: your/log/directory/0718-19-52-36-748
Trainable/total parameters of the model: 0.03M / 38.64M (0.08843%)



      Epoch   GPU Mem.       Time       Loss         LR


        1/5     0.589G     0.1355      4.602      0.001: 100%|██████████| 16.0/16.0 [00:01<00:00, 12.9batch/s]



      Epoch   GPU Mem.       Time      Acc@1      Acc@5


        1/5     0.629G    0.03114      1.871      7.932: 100%|██████████| 157/157 [00:05<00:00, 30.9batch/s] 
***   Best results: [Acc@1: 1.8710191082802548], [Acc@5: 7.931926751592357]



      Epoch   GPU Mem.       Time       Loss         LR


        2/5     0.784G     0.1016      4.538 0.00090451: 100%|██████████| 16.0/16.0 [00:00<00:00, 19.4batch/s]



      Epoch   GPU Mem.       Time      Acc@1      Acc@5


        2/5     0.784G    0.02669      2.498      9.504: 100%|██████████| 157/157 [00:04<00:00, 35.9batch/s] 
***   Best results: [Acc@1: 2.4980095541401273], [Acc@5: 9.504378980891719]



      Epoch   GPU Mem.       Time       Loss         LR


        3/5     0.784G    0.09631      4.488 0.00065451: 100%|██████████| 16.0/16.0 [00:00<00:00, 20.6batch/s]



      Epoch   GPU Mem.       Time      Acc@1      Acc@5


        3/5     0.784G    0.02688      2.379      10.16: 100%|██████████| 157/157 [00:04<00:00, 36.0batch/s] 
***   Best results: [Acc@1: 2.3785828025477707], [Acc@5: 10.161226114649681]



      Epoch   GPU Mem.       Time       Loss         LR


        4/5     0.784G    0.09126       4.45 0.00034549: 100%|██████████| 16.0/16.0 [00:00<00:00, 20.2batch/s]



      Epoch   GPU Mem.       Time      Acc@1      Acc@5


        4/5     0.784G    0.02644      2.468      10.29: 100%|██████████| 157/157 [00:04<00:00, 36.2batch/s] 
***   Best results: [Acc@1: 2.468152866242038], [Acc@5: 10.290605095541402]



      Epoch   GPU Mem.       Time       Loss         LR


        5/5     0.784G     0.0936      4.431 9.5492e-05: 100%|██████████| 16.0/16.0 [00:00<00:00, 20.5batch/s]



      Epoch   GPU Mem.       Time      Acc@1      Acc@5


        5/5     0.784G    0.02706      2.558      10.43: 100%|██████████| 157/157 [00:04<00:00, 35.8batch/s] 
***   Best results: [Acc@1: 2.557722929936306], [Acc@5: 10.429936305732484]



      Epoch   GPU Mem.       Time      Acc@1      Acc@5


        1/5     0.784G    0.02667      2.558      10.43: 100%|██████████| 157/157 [00:04<00:00, 36.0batch/s] 
***   Best results: [Acc@1: 2.557722929936306], [Acc@5: 10.429936305732484]


(2.557722929936306, 10.429936305732484)