# Calls to Nvidia-DLFramework-Inspect

Let's recap how Nvidia-DLFramework-Inspect and Transformer Engine work together. TransformerEngine (TE) layers initialized with `debug=True` contain hook calls inside each of their GEMMs. Users can define their own feature classes or use the feature classes provided with TE. The `config.yaml` file specifies which features are used in which layers. Nvidia-DLFramework-Inspect combines three things (TE training, the feature classes, and `config.yaml`) and takes care of inserting the hooks in the correct places. This process is shown in the figure below.

<figure align="center">
<img src="./img/api_calls.svg">
    <figcaption> Fig 1: Example of Nvidia-DLFramework-Inspect affecting a training script with 3 TE Linear layers. Two feature classes are defined, and <i>config.yaml</i> specifies which features should be used in each layer. In "Layer 1", process_tensor() from Feature 2 is inserted, and in "Layer 2", process_tensor() from Feature 1 is inserted. "Layer 3" is not affected. </figcaption>
</figure>
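
The routing shown in Fig 1 can be sketched in plain Python. This is a minimal, hypothetical illustration, not the Nvidia-DLFramework-Inspect API: the classes `Feature1`/`Feature2`, the dictionary `CONFIG` standing in for the parsed `config.yaml`, and the function `process_tensor_hook` are all invented here, and a list of floats stands in for `torch.Tensor` so that the sketch runs without dependencies.

```python
# Hypothetical sketch of Fig 1; NOT the real Nvidia-DLFramework-Inspect API.
# A plain list of floats stands in for torch.Tensor to keep it dependency-free.


class Feature1:
    """Hypothetical feature: scales every element by 0.5."""

    def process_tensor(self, tensor, gemm, tensor_name):
        return [0.5 * x for x in tensor]


class Feature2:
    """Hypothetical feature: clamps every element to [-1, 1]."""

    def process_tensor(self, tensor, gemm, tensor_name):
        return [max(-1.0, min(1.0, x)) for x in tensor]


# Stand-in for the parsed config.yaml: which feature is attached to which layer.
CONFIG = {
    "layer_1": Feature2(),
    "layer_2": Feature1(),
    # "layer_3" is absent, so it is not affected.
}


def process_tensor_hook(layer_name, tensor, gemm, tensor_name):
    """Route the hook call to the feature configured for this layer, if any."""
    feature = CONFIG.get(layer_name)
    if feature is None:
        return tensor  # no feature configured: pass the tensor through unchanged
    return feature.process_tensor(tensor, gemm, tensor_name)


x = [-3.0, 0.25, 2.0]
print(process_tensor_hook("layer_1", x, "fprop", "activation"))  # [-1.0, 0.25, 1.0]
print(process_tensor_hook("layer_2", x, "fprop", "activation"))  # [-1.5, 0.125, 1.0]
print(process_tensor_hook("layer_3", x, "fprop", "activation"))  # [-3.0, 0.25, 2.0]
```

The layer code stays the same in every case; only the configuration decides which feature (if any) handles the call.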

This page lists all calls that TransformerEngine makes to Nvidia-DLFramework-Inspect for each GEMM. Their order is shown in the figure below.


<figure align="center">
<img src="./img/api_calls2.svg">
    <figcaption> Fig 2: The calls to Nvidia-DLFramework-Inspect made for one GEMM in Transformer Engine when the layer has <code>debug=True</code>. The <code>fp8_gemm()</code> call determines whether the GEMM is performed in FP8 or in high precision. If it returns <i>False</i>, the quantization-related calls are not invoked.</figcaption>
</figure>
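
The per-GEMM flow can be approximated as follows. This is a hypothetical sketch: `TracingFeature`, `fake_quantize`, and `prepare_gemm_input` are invented names, it covers only a GEMM input (not the output side), and the relative order of `process_tensor`, `save_stats_for_logging`, and the quantization-related calls is an assumption; Fig 2 remains the authoritative order. What it does encode is the rule from the caption: the quantization-related calls run only when `fp8_gemm()` returns `True`, and `fp8_gemm()` itself is consulted only inside an enabled FP8 autocast.

```python
# Hypothetical sketch of the per-GEMM hook sequence. The ordering of calls is an
# assumption for illustration, not Transformer Engine's actual code.


class TracingFeature:
    """Hypothetical feature that records every hook call it receives."""

    def __init__(self, use_fp8):
        self.use_fp8 = use_fp8
        self.trace = []

    def fp8_gemm(self, gemm):
        self.trace.append("fp8_gemm")
        return self.use_fp8

    def process_tensor(self, tensor, gemm, tensor_name):
        self.trace.append("process_tensor")
        return tensor

    def save_stats_for_logging(self, tensor, gemm, tensor_name):
        self.trace.append("save_stats_for_logging")

    def process_quantized_tensor(self, tensor, gemm, tensor_name):
        self.trace.append("process_quantized_tensor")
        return tensor

    def save_stats_for_logging_quantized(self, tensor, gemm, tensor_name):
        self.trace.append("save_stats_for_logging_quantized")


def fake_quantize(tensor):
    """Placeholder for quantization; a real layer would produce a QuantizedTensor."""
    return [round(x, 1) for x in tensor]


def prepare_gemm_input(feature, tensor, gemm, tensor_name, fp8_autocast):
    # fp8_gemm() is only consulted inside an enabled FP8 autocast
    # (the `and` short-circuits, so it is not called otherwise).
    use_fp8 = fp8_autocast and feature.fp8_gemm(gemm)
    tensor = feature.process_tensor(tensor, gemm, tensor_name)
    feature.save_stats_for_logging(tensor, gemm, tensor_name)
    if use_fp8:
        # Quantization-related calls are skipped when fp8_gemm() returns False.
        q = fake_quantize(tensor)
        q = feature.process_quantized_tensor(q, gemm, tensor_name)
        feature.save_stats_for_logging_quantized(q, gemm, tensor_name)
        return q
    return tensor
```

Running `prepare_gemm_input` with `TracingFeature(use_fp8=False)` leaves only `fp8_gemm`, `process_tensor`, and `save_stats_for_logging` in the trace; with `use_fp8=True` the two quantization-related calls are appended.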

### process_tensor

This call is invoked before and after every GEMM. It allows arbitrary tensor processing to be inserted. For example, the `FakeCastFp8` feature uses it to emulate casting to FP8 while returning the tensor in high precision.

Args:

- `tensor: torch.Tensor`,
- `gemm: str` – one of [`fprop`, `dgrad`, `wgrad`],
- `tensor_name: str` – one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`].

Should return:

- `processed_tensor: torch.Tensor` – tensor after processing.
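
To illustrate the `process_tensor` contract, here is a minimal fake-cast feature in the spirit of `FakeCastFp8`; it is not that feature's implementation. Assumptions: the class `FakeCastE4M3Sketch` and the helper `round_to_e4m3` are hypothetical, a list of floats stands in for `torch.Tensor`, values saturate at ±448 instead of overflowing, NaN/inf are not handled, and no scaling factor is applied before the cast (real FP8 recipes normally scale first).

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 magnitude
E4M3_MIN_NORMAL_EXP = -6  # smallest normal exponent; below it values are subnormal
E4M3_MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Round one float to the nearest FP8 E4M3 value (saturating; NaN/inf not handled)."""
    if x == 0.0:
        return x
    a = min(abs(x), E4M3_MAX)  # saturate instead of overflowing
    _, exp = math.frexp(a)  # a = m * 2**exp with 0.5 <= m < 1
    e = max(exp - 1, E4M3_MIN_NORMAL_EXP)  # unbiased exponent, floored at the subnormal range
    step = math.ldexp(1.0, e - E4M3_MANTISSA_BITS)  # spacing of representable values
    return math.copysign(round(a / step) * step, x)  # round() is round-half-to-even


class FakeCastE4M3Sketch:
    """Hypothetical feature: emulates an FP8 E4M3 cast but keeps high precision."""

    def process_tensor(self, tensor, gemm, tensor_name):
        # Illustrative choice: only fake-cast the forward GEMM inputs.
        if gemm != "fprop" or tensor_name not in ("activation", "weight"):
            return tensor
        return [round_to_e4m3(x) for x in tensor]
```

For example, `FakeCastE4M3Sketch().process_tensor([0.3, 1000.0], "fprop", "activation")` returns `[0.3125, 448.0]`: each value is snapped to the E4M3 grid, but the result is still a list of ordinary Python floats, mirroring how a fake cast returns a high-precision tensor.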

### process_quantized_tensor

This call allows the tensor to be processed after quantization.

Args:

- `tensor: transformer_engine.pytorch.QuantizedTensor`,
- `gemm: str` – one of [`fprop`, `dgrad`, `wgrad`],
- `tensor_name: str` – one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`].

Should return:

- `processed_tensor: torch.Tensor` – tensor after processing.

### fp8_gemm

This call determines whether the GEMM is executed in FP8 or in high precision. It is called only if the forward pass runs inside an enabled FP8 autocast.

Args:

- `gemm: str` – one of [`fprop`, `dgrad`, `wgrad`].

Should return:

- `fp8_gemm: bool` – `True` if the GEMM should be executed in FP8, `False` if it should be executed in high precision.
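
A feature that keeps selected GEMMs in high precision only needs to implement `fp8_gemm`. The class `HighPrecisionGemmsSketch` and its constructor argument are hypothetical, chosen for illustration.

```python
class HighPrecisionGemmsSketch:
    """Hypothetical feature: keeps the listed GEMMs in high precision."""

    VALID_GEMMS = {"fprop", "dgrad", "wgrad"}

    def __init__(self, high_precision_gemms):
        unknown = set(high_precision_gemms) - self.VALID_GEMMS
        if unknown:
            raise ValueError(f"unknown GEMM names: {sorted(unknown)}")
        self.high_precision_gemms = set(high_precision_gemms)

    def fp8_gemm(self, gemm: str) -> bool:
        # True -> run this GEMM in FP8; False -> fall back to high precision,
        # in which case the quantization-related hooks are not invoked.
        return gemm not in self.high_precision_gemms
```

With `HighPrecisionGemmsSketch({"wgrad"})`, the forward and data-gradient GEMMs stay in FP8 while the weight-gradient GEMM runs in high precision.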

### save_stats_for_logging_quantized

This call receives the tensor after quantization so that statistics can be collected from it. It is used by the `LogFp8TensorStats` feature.

Args:

- `tensor: transformer_engine.pytorch.QuantizedTensor`,
- `gemm: str` – one of [`fprop`, `dgrad`, `wgrad`],
- `tensor_name: str` – one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`].

Should return nothing.
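
A minimal sketch of a logging hook for quantized tensors. `ZeroFractionStatsSketch` is hypothetical and is not `LogFp8TensorStats`; to stay dependency-free it assumes the quantized values arrive as an iterable of floats, whereas a real `QuantizedTensor` would first have to be converted back to high precision. It records, per `(gemm, tensor_name)` pair, the fraction of elements that are exactly zero after quantization.

```python
from collections import defaultdict


class ZeroFractionStatsSketch:
    """Hypothetical feature: records the fraction of exact zeros after quantization."""

    def __init__(self):
        # (gemm, tensor_name) -> list of recorded fractions, one per call
        self.stats = defaultdict(list)

    def save_stats_for_logging_quantized(self, tensor, gemm, tensor_name):
        # `tensor` is assumed to be an iterable of already-quantized values.
        values = list(tensor)
        zeros = sum(1 for v in values if v == 0.0)
        self.stats[(gemm, tensor_name)].append(zeros / len(values) if values else 0.0)
        # Returns nothing, as the hook requires.
```

Keeping one list entry per call lets a logger later aggregate the values per step or per layer.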

### save_stats_for_logging

This call receives the tensor in high precision so that statistics can be collected from it. It is used by the `LogTensorStats` feature.

Args:

- `tensor: torch.Tensor`,
- `gemm: str` – one of [`fprop`, `dgrad`, `wgrad`],
- `tensor_name: str` – one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`].

Should return nothing.
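
Similarly, a hypothetical `LogTensorStats`-style hook can be sketched with the standard library. `BasicStatsSketch` is invented for illustration and is not the `LogTensorStats` implementation; a list of floats stands in for `torch.Tensor`.

```python
import statistics
from collections import defaultdict


class BasicStatsSketch:
    """Hypothetical feature: collects min/max/mean/std of high-precision tensors."""

    def __init__(self):
        # (gemm, tensor_name) -> list of per-call stat dictionaries
        self.stats = defaultdict(list)

    def save_stats_for_logging(self, tensor, gemm, tensor_name):
        values = [float(v) for v in tensor]
        if not values:
            return  # nothing to record for an empty tensor
        self.stats[(gemm, tensor_name)].append({
            "min": min(values),
            "max": max(values),
            "mean": statistics.fmean(values),
            "std": statistics.pstdev(values),  # population standard deviation
        })
        # Returns nothing, as the hook requires.
```

For the input `[1.0, 2.0, 3.0, 4.0]`, the recorded entry has `min` 1.0, `max` 4.0, `mean` 2.5, and `std` equal to the square root of 1.25.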



