# Porting EXIF-SC implementation to PyTorch

Official code repository: https://github.com/minyoungg/selfconsistency

1. Generate PyTorch model-building code for the TensorFlow-Slim ResNet-50 v2 model using [MMdnn](https://github.com/Microsoft/MMdnn/blob/master/docs/tf2pytorch.md)

```
pip install mmdnn
mmdownload -f tensorflow -n resnet_v2_50
mmtoir -f tensorflow -n imagenet_resnet_v2_50.ckpt.meta -w imagenet_resnet_v2_50.ckpt --dstNode MMdnn_Output -o converted
mmtocode -f pytorch -n converted.pb -w converted.npy -d converted_pytorch.py -dw converted_pytorch.npy
```
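Before modifying the generated code, it can help to see which layers and parameter names the MMdnn weight file contains. The sketch below assumes that `converted_pytorch.npy` is a pickled dict mapping layer names to dicts of NumPy arrays. This is the same dict-of-dicts structure this notebook saves at the end, and it is how the generated PyTorch code is expected to read the file, but that is an assumption worth checking against your MMdnn version. The helper name `layer_shapes` is hypothetical.

```python
import numpy as np


def layer_shapes(npy_path):
    """Return {layer_name: {param_name: shape}} for a pickled dict-of-dicts .npy file.

    Assumption: the file holds a 0-d object array wrapping a dict, as produced by
    np.save(path, some_dict). This is an assumed format for MMdnn's output.
    """
    weights = np.load(npy_path, allow_pickle=True).item()
    return {
        layer: {param: np.shape(array) for param, array in params.items()}
        for layer, params in weights.items()
    }
```

For example, `layer_shapes('converted_pytorch.npy')` lists each layer's parameter shapes, which can then be lined up against the checkpoint variables examined in step 3.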

2. Download the EXIF-SC model checkpoint from the [official repo](https://github.com/minyoungg/selfconsistency)

3. Examine the variables in the TensorFlow checkpoint `exif_final.ckpt`.
    - Extract all relevant weights and make the changes needed to load them into PyTorch layers (for example, transposing convolution and linear weights).
    - Modify the model-building code in `converted_pytorch.py` so that it loads those weights into the PyTorch model.
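Part of loading the weights is deciding which PyTorch parameter each TensorFlow variable feeds. The TF-Slim suffixes below (`weights`, `biases`, `gamma`, `beta`, `moving_mean`, `moving_variance`) are the standard ones, and the right-hand names are the parameters and buffers of `nn.Conv2d`, `nn.Linear` and `nn.BatchNorm2d`. The module naming scheme (TF scope `a/b/c` becomes module `a_b_c`) and the function name `tf_name_to_torch_key` are hypothetical; adjust them to whatever names `converted_pytorch.py` actually uses.

```python
# Standard TF-Slim variable suffixes and the PyTorch parameter/buffer names they
# usually correspond to in nn.Conv2d / nn.Linear / nn.BatchNorm2d.
TF_TO_TORCH_SUFFIX = {
    "weights": "weight",
    "biases": "bias",
    "gamma": "weight",          # BatchNorm scale
    "beta": "bias",             # BatchNorm offset
    "moving_mean": "running_mean",
    "moving_variance": "running_var",
}


def tf_name_to_torch_key(tf_name):
    """Map a TF variable name to a hypothetical PyTorch state_dict key.

    Hypothetical naming scheme: the PyTorch module for TF scope 'a/b/c' is
    named 'a_b_c'. Change this to match converted_pytorch.py.
    """
    scope, _, suffix = tf_name.rpartition("/")
    if suffix not in TF_TO_TORCH_SUFFIX:
        raise KeyError(f"no PyTorch equivalent known for {suffix!r}")
    module_name = scope.replace("/", "_")
    return f"{module_name}.{TF_TO_TORCH_SUFFIX[suffix]}"
```

Raising on unknown suffixes, rather than skipping them silently, makes any variable the mapping does not cover show up immediately.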

In [None]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm

In [None]:
ckpt_path = 'ckpt/exif_final/exif_final.ckpt'

In [None]:
tf_vars = tf.train.list_variables(ckpt_path)

In [None]:
for name, shape in tf_vars:
    print(name, shape)
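A quick way to sanity-check the printed list is to separate the model weights from the Adam optimizer state (the same suffixes the conversion cell below filters out) and count how many values each group holds. This small pure-Python helper works on the `(name, shape)` pairs returned by `tf.train.list_variables`; the function name `summarize_variables` is hypothetical.

```python
from math import prod

# Suffixes of optimizer state that inference does not need (same as the filter below).
OPTIMIZER_SUFFIXES = {"beta1_power", "beta2_power", "Adam", "Adam_1"}


def summarize_variables(tf_vars):
    """Return {'model': [count, elements], 'optimizer': [count, elements]}."""
    summary = {"model": [0, 0], "optimizer": [0, 0]}
    for name, shape in tf_vars:
        kind = "optimizer" if name.split("/")[-1] in OPTIMIZER_SUFFIXES else "model"
        summary[kind][0] += 1
        summary[kind][1] += prod(shape)  # prod([]) == 1, so scalars count as one value
    return summary
```

Calling `summarize_variables(tf_vars)` gives a parameter count for the model group that can be compared with the PyTorch model's parameter count after loading.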

In [None]:
# Drop optimizer state (Adam slots and beta power accumulators), which inference does not need,
# and convert the remaining weights to the layout PyTorch expects
weights_dict = {}

for name, _ in tqdm(tf_vars):

    name_split = name.split('/')
    weight_type = name_split[-1]
    
    # Exclude unnecessary variables
    if weight_type in ['beta1_power', 'beta2_power', 'Adam', 'Adam_1']:
        continue

    weight_name = '/'.join(name_split[:-1])

    weights = tf.train.load_variable(ckpt_path, name)
    
    # Transpose CNN weights
    # [H, W, C, F] -> [F, C, H, W]
    if len(weights.shape) == 4:
        weights = np.transpose(weights, (3, 2, 0, 1))
    # Transpose linear (fully connected) weight matrices: [in, out] -> [out, in]
    if len(weights.shape) == 2:
        weights = np.transpose(weights, (1, 0))

    if weight_name not in weights_dict:
        weights_dict[weight_name] = {}
    weights_dict[weight_name][weight_type] = weights
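The transposes in the cell above can be checked numerically without either framework. TensorFlow and PyTorch convolutions both compute cross-correlations (no kernel flip), so one output pixel is the same weighted sum in either layout; likewise, TF computes `x @ W` with `W` of shape `[in, out]`, while `nn.Linear` computes `x @ W.T` with `W` of shape `[out, in]`. This NumPy sketch uses random arrays with small, arbitrary shapes (not EXIF-SC's real sizes) and compares the two computations.

```python
import numpy as np

rng = np.random.default_rng(0)

# Convolution: TF kernel [H, W, C, F] -> PyTorch kernel [F, C, H, W].
k_tf = rng.standard_normal((3, 3, 4, 5))
k_pt = np.transpose(k_tf, (3, 2, 0, 1))
patch_hwc = rng.standard_normal((3, 3, 4))      # one receptive field in TF (NHWC) order
patch_chw = np.transpose(patch_hwc, (2, 0, 1))  # the same field in PyTorch (NCHW) order
out_tf = np.einsum("hwc,hwcf->f", patch_hwc, k_tf)
out_pt = np.einsum("chw,fchw->f", patch_chw, k_pt)

# Linear: TF weight [in, out] -> PyTorch weight [out, in].
w_tf = rng.standard_normal((8, 3))
w_pt = np.transpose(w_tf, (1, 0))
x = rng.standard_normal((2, 8))
lin_tf = x @ w_tf      # TF: x @ W
lin_pt = x @ w_pt.T    # nn.Linear: x @ W.T

print(np.allclose(out_tf, out_pt), np.allclose(lin_tf, lin_pt))
```

This only checks the weight layout. Padding is a separate concern: TF's `'SAME'` padding can pad asymmetrically for stride-2 convolutions, which the PyTorch model code has to reproduce.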

In [None]:
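# Save the weights into a separate file. np.save stores the dict as a pickled
# 0-d object array, so reload it with np.load(path, allow_pickle=True).item()
np.save('ckpt/resnet_50_pt/exif_final.npy', weights_dict)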
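Once the file is reloaded with `np.load(path, allow_pickle=True).item()`, a cheap check before wiring the arrays into `converted_pytorch.py` is to compare each array's shape against what the target layers expect. In the sketch below, `check_shapes` and the `expected_shapes` dict are hypothetical; in practice the expected shapes could be built from the PyTorch model's `state_dict()` after mapping its keys back to TF scope names.

```python
import numpy as np


def check_shapes(weights_dict, expected_shapes):
    """Return the (layer, param) pairs that are missing or have the wrong shape.

    weights_dict:    {tf_scope: {tf_suffix: np.ndarray}}, as saved above.
    expected_shapes: hypothetical {tf_scope: {tf_suffix: shape}} for the target model.
    """
    mismatches = []
    for layer, params in expected_shapes.items():
        for param, shape in params.items():
            actual = weights_dict.get(layer, {}).get(param)
            if actual is None or tuple(actual.shape) != tuple(shape):
                mismatches.append((layer, param))
    return mismatches
```

An empty result means every expected array is present with the right (already transposed) shape; each array can then be copied into its layer, for example via `torch.from_numpy`.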