## load library

In [496]:
#python

import os

from datetime import datetime
import yaml
from typing import Any, Dict, Tuple, Union, List

import random

from collections import defaultdict

from tqdm import tqdm,notebook

import copy 

import json

#sklearn
from sklearn.metrics import f1_score

#numpy

import numpy as np

#pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

#musco
from musco.pytorch import CompressorVBMF, CompressorPR, CompressorManual
from flopco import FlopCo
from musco.pytorch.compressor.rank_selection.estimator import estimate_rank_for_compression_rate, estimate_vbmf_ranks


#baseline

import warnings

warnings.filterwarnings('ignore')

import sys

sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('/opt/ml/code/src/'))))

import src

from src.dataloader import create_dataloader
from src.loss import CustomCriterion
from src.model import Model
from src.trainer import TorchTrainer
from src.utils.common import get_label_counts, read_yaml
from src.utils.macs import calc_macs
from src.utils.torch_utils import check_runtime, model_info
from src.augmentation.policies import simple_augment_test
from src.utils.inference_utils import run_model
from src.utils.torch_utils import save_model

## load base model

In [352]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [558]:
model_config = read_yaml(cfg="exp/2021-06-14_15-59-59/model.yml")
data_config = read_yaml(cfg="exp/2021-06-14_15-59-59/data.yml")

model_config = read_yaml(cfg=model_config)
data_config = read_yaml(cfg=data_config)

In [559]:
model_instance = Model(model_config,verbose=True)

idx |   n |     params |          module |            arguments |   in_channel |   out_channel
----------------------------------------------------------------------------------------------
  0 |   1 |        464 |            Conv | [16, 3, 2, None, 1, 'Hardswish'] |            3           16
  1 |   1 |      2,336 |            Conv | [16, 3, 2, None, 1, 'ReLU'] |           16           16
  2 |   1 |      2,336 |            Conv | [16, 3, 2, None, 1, 'ReLU'] |           16           16
  3 |   5 |      1,680 |          DWConv | [112, 1, 2, None, 'ReLU'] |           16          112
  4 |   1 |     12,096 |          DWConv | [448, 5, 1, None, 'ReLU'] |          112          448
  5 |   1 |    403,200 |            Conv |          [896, 1, 1] |          448          896
  6 |   1 |          0 |   GlobalAvgPool |                   [] |          896          896
  7 |   1 |      8,082 |       FixedConv | [9, 1, 1, None, 1, None] |          896            9
Model Summary: 47 layers, 430,194 

In [398]:
model_path = 'exp/0.5177_100epoch_1120/best.pt'

if os.path.isfile(model_path):
    model_instance.model.load_state_dict(torch.load(model_path, map_location=device))

In [560]:
model_instance.model = model_instance.model.to(device)

## calculate original model statistics

In [473]:
model_instance.model = model_instance.model.eval()
model_stats = FlopCo(model_instance.model, device = device)

In [474]:
model_stats.total_flops,  model_stats.relative_flops

(191406208,
 defaultdict(None,
             {'0.conv': 0.056623116424729544,
              '1.conv': 0.07549748856630606,
              '2.conv': 0.018874372141576515,
              '3.0.conv': 0.0009175042013266362,
              '3.1.conv': 0.0009175042013266362,
              '3.2.conv': 0.0009175042013266362,
              '3.3.conv': 0.0009175042013266362,
              '3.4.conv': 0.00022937605033165905,
              '4.conv': 0.022937605033165904,
              '5.conv': 0.822083764388666,
              '7.conv': 8.42605899177523e-05}))

## find compress layer

In [475]:
all_layer = [k for k in model_stats.flops.keys()]

In [476]:
lnames_to_compress = [k for k in model_stats.flops.keys() if\
                      model_stats.ltypes[k]['type'] == nn.Conv2d and\
                      model_stats.ltypes[k]['groups'] == 1
                     ]

In [477]:
len(all_layer), all_layer

(11,
 ['0.conv',
  '1.conv',
  '2.conv',
  '3.0.conv',
  '3.1.conv',
  '3.2.conv',
  '3.3.conv',
  '3.4.conv',
  '4.conv',
  '5.conv',
  '7.conv'])

In [478]:
len(lnames_to_compress), lnames_to_compress

(5, ['0.conv', '1.conv', '2.conv', '5.conv', '7.conv'])

## find max rank for compress layer

In [479]:
wf = 0.55
nx = 20. # compression rate

max_ranks = defaultdict()

#lnames_compress_me = []

for mname, m in model_instance.model.named_modules():
    if mname in lnames_to_compress:
        lname = mname
        _, cin, _, _ = model_stats.input_shapes[lname][0]
        _, cout, _, _ = model_stats.output_shapes[lname][0]
        kernel_size = model_stats.ltypes[lname]['kernel_size']

        tensor_shape = (cout, cin, *kernel_size)
        r_pr = estimate_rank_for_compression_rate(tensor_shape, rate = nx, key = 'tucker2')
        
        #r_vbmf = estimate_vbmf_ranks(m.weight.data[:, :, 0, 0])
        
        max_ranks[lname] = r_pr

        #print('\n', lname, tensor_shape, r_pr, r_vbmf)
        print('\n', lname, tensor_shape, r_pr)
        #if r_pr > r_vbmf:
            #lnames_compress_me.append(lname)
            #print('===== COMPRESS ME ===== {} times\n'.format(r_pr/r_vbmf))
        #else:
            #print('===== DO NOT COMPRESS ME =====\n')
    else:
        if mname in all_layer:
            max_ranks[mname] = None


 0.conv (16, 3, 3, 3) (8, 2)

 1.conv (16, 16, 3, 3) (2, 2)

 2.conv (16, 16, 3, 3) (2, 2)

 5.conv (896, 448, 1, 1) (16, 10)

 7.conv (9, 896, 1, 1) (2, 2)


In [480]:
max_ranks

defaultdict(None,
            {'0.conv': (8, 2),
             '1.conv': (2, 2),
             '2.conv': (2, 2),
             '3.0.conv': None,
             '3.1.conv': None,
             '3.2.conv': None,
             '3.3.conv': None,
             '3.4.conv': None,
             '4.conv': None,
             '5.conv': (16, 10),
             '7.conv': (2, 2)})

## define basic compression function

In [87]:
def calibrate(model, device, train_loader, max_iters = 1000,
              freeze_lnames = None):

    model.to(device).train()
    for pname, p in model.named_parameters():
        
        if pname.strip('.weight').strip('.bias')  in freeze_lnames:
            p.requires_grad = False

    with torch.no_grad():
        for i, (data, _) in notebook.tqdm(enumerate(train_loader)):
            _ = model(data.to(device))

            if i > max_iters:
                break

            del data
            torch.cuda.empty_cache()
            
    model.eval()
    return model

In [120]:
## github test function

def test(model, device, test_loader):
    model.eval()
    #test_loss = 0
    correct = 0
    #criterion = nn.CrossEntropyLoss()
    
    preds = []
    gt = []
    
    num_classes=9
    label_list = [i for i in range(num_classes)]
    
    with torch.no_grad():
        for i,(data, target) in notebook.tqdm(enumerate(test_loader)):
            data, target = data.to(device), target.to(device)

            
            output = model(data)
            
            
            #test_loss += criterion(output, target).item() # sum up batch loss
            _, pred_f1 = torch.max(output, 1)
            
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            
            preds += pred_f1.to("cpu").tolist()
            gt += target.to("cpu").tolist()
            
            
            
    #test_loss /= len(test_loader.dataset)
    
    accuracy = 100. * correct / len(test_loader.dataset)
    
    f1 = f1_score(
        y_true=gt, y_pred=preds, labels=label_list, average="macro", zero_division=0
    )
  
    print('\nTest set: Accuracy: {}/{} ({:.2f}%) , Test f1: {:.2f} \n'.format(
        correct, len(test_loader.dataset),
        accuracy, f1))
    
    
    del data
    torch.cuda.empty_cache()
    
    return  f1

In [89]:
#baseline test function

@torch.no_grad()
def test(model,test_dataloader):
    """Test model.

    Args:
        test_dataloader: test data loader module which is a iterator that returns (data, labels)

    Returns:
        loss, f1, accuracy
    """

    #n_batch = _get_n_batch_from_dataloader(test_dataloader)

    #running_loss = 0.0
    preds = []
    gt = []
    correct = 0
    total = 0

    num_classes = 9
    label_list = [i for i in range(num_classes)]

    pbar = notebook.tqdm(enumerate(test_dataloader), total=len(test_dataloader))
    model.to(device)
    model.eval()
    for batch, (data, labels) in pbar:
        data, labels = data.to(device), labels.to(device)

        
        outputs = model(data)
        outputs = torch.squeeze(outputs)
        #running_loss += criterion(outputs, labels).item()

        _, pred = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()

        preds += pred.to("cpu").tolist()
        gt += labels.to("cpu").tolist()
        
    #loss = running_loss / len(test_dataloader)
    accuracy = correct / total
    f1 = f1_score(
        y_true=gt, y_pred=preds, labels=label_list, average="macro", zero_division=0
    )
    
    print('\nTest set: Accuracy: {}/{} ({:.2f}%) , Test f1: {:.2f} \n'.format(
        correct, total,
        accuracy, f1))
    
    return f1

## SVD compression

In [27]:
train_dl, val_dl, test_dl = create_dataloader(data_config)

f1_init = 0.0
eps = 0.01


min_ranks = {k:2 for k in max_ranks.keys()}

curr_ranks = copy.deepcopy(max_ranks)


for idx,lname in enumerate(lnames_compress_me):
    print(lname)
    
    curr = max_ranks[lname]
    curr_max = max_ranks[lname]
    curr_min = min_ranks[lname]
    
    for i in range(4):
        print(curr_min, curr, curr_max)
        
        ranks =  {k:None for k in model_stats.flops.keys() if k not in [lname]}
        ranks[lname] = curr
        
    
        compressor = CompressorManual(model_instance.model, model_stats,\
                              ranks = ranks,
                              ft_every = 1,\
                                      nglobal_compress_iters = 1)

        print("\n Compress")
        compressor.compression_step()

        print("\n Calibrate")
        compressor.model = calibrate(compressor.compressed_model,\
                                     device, train_dl,\
                                     freeze_lnames = lnames_to_compress[:idx])

        print("\n Test")
        f1 = test(compressor.compressed_model, val_dl)

        if f1 + eps < f1_init:
            if i == 0:
                print('Bad layer to compress')
                break
            else:
                curr_min = curr
                curr = curr + (curr_max - curr)//2
        else:
            curr_max = curr
            curr = curr - (curr - curr_min)//2
            
        macs = calc_macs(compressor.compressed_model, (3, data_config["IMG_SIZE"], data_config["IMG_SIZE"]))
        print(f"macs: {macs}")
    
    model_instance.model = compressor.compressed_model
            
        
    print('\n Fine-tune')

4.0.conv.7
2 38 38

 Compress
4.0.conv.7 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 5136/8148 (0.63%) , Test f1: 0.53 

macs: 11240834.0
2 20 38

 Compress
4.0.conv.7 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 5140/8148 (0.63%) , Test f1: 0.53 

macs: 11211026.0
2 11 20

 Compress
4.0.conv.7 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 5110/8148 (0.63%) , Test f1: 0.51 

macs: 11196122.0
2 7 11

 Compress
4.0.conv.7 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4908/8148 (0.60%) , Test f1: 0.44 

macs: 11189498.0

 Fine-tune
4.1.conv.0
2 99 99

 Compress
4.1.conv.0 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4888/8148 (0.60%) , Test f1: 0.44 

macs: 11186618.0
2 51 99

 Compress
4.1.conv.0 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4873/8148 (0.60%) , Test f1: 0.43 

macs: 10937786.0
2 27 51

 Compress
4.1.conv.0 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4852/8148 (0.60%) , Test f1: 0.43 

macs: 10813370.0
2 15 27

 Compress
4.1.conv.0 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4840/8148 (0.59%) , Test f1: 0.42 

macs: 10751162.0

 Fine-tune
4.1.conv.5.fc1
2 89 89

 Compress
4.1.conv.5.fc1 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4860/8148 (0.60%) , Test f1: 0.42 

macs: 10750826.0
2 46 89

 Compress
4.1.conv.5.fc1 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4833/8148 (0.59%) , Test f1: 0.42 

macs: 10726746.0
2 24 46

 Compress
4.1.conv.5.fc1 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4864/8148 (0.60%) , Test f1: 0.43 

macs: 10714426.0
2 13 24

 Compress
4.1.conv.5.fc1 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4841/8148 (0.59%) , Test f1: 0.43 

macs: 10708266.0

 Fine-tune
4.1.conv.5.fc2
2 89 89

 Compress
4.1.conv.5.fc2 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4846/8148 (0.59%) , Test f1: 0.43 

macs: 10707930.0
2 46 89

 Compress
4.1.conv.5.fc2 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4838/8148 (0.59%) , Test f1: 0.42 

macs: 10683850.0
2 24 46

 Compress
4.1.conv.5.fc2 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4810/8148 (0.59%) , Test f1: 0.42 

macs: 10671530.0
2 13 24

 Compress
4.1.conv.5.fc2 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4826/8148 (0.59%) , Test f1: 0.41 

macs: 10665370.0

 Fine-tune
4.1.conv.7
2 99 99

 Compress
4.1.conv.7 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4826/8148 (0.59%) , Test f1: 0.42 

macs: 10662490.0
2 51 99

 Compress
4.1.conv.7 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4840/8148 (0.59%) , Test f1: 0.42 

macs: 10413658.0
2 27 51

 Compress
4.1.conv.7 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4828/8148 (0.59%) , Test f1: 0.42 

macs: 10289242.0
2 15 27

 Compress
4.1.conv.7 svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4805/8148 (0.59%) , Test f1: 0.42 

macs: 10227034.0

 Fine-tune
5.conv
2 106 106

 Compress
5.conv svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4793/8148 (0.59%) , Test f1: 0.42 

macs: 10222426.0
2 54 106

 Compress
5.conv svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4820/8148 (0.59%) , Test f1: 0.42 

macs: 9863002.0
2 28 54

 Compress
5.conv svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4798/8148 (0.59%) , Test f1: 0.41 

macs: 9683290.0
2 15 28

 Compress
5.conv svd

 Calibrate


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



 Test


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=255.0), HTML(value='')))



Test set: Accuracy: 4819/8148 (0.59%) , Test f1: 0.42 

macs: 9593434.0

 Fine-tune


## tucker2 decomposition compression

In [324]:
for lname in all_layer:
    
    ranks =  {k:None for k in model_stats.flops.keys() if k not in [lname]}
    
    ranks[lname]=max_ranks[lname]
    
    print(ranks)
    
    if ranks[lname] == None:
        continue
    

    compressor = CompressorManual(model_instance.model, model_stats,conv2d_nn_decomposition='tucker2',ranks = ranks, ft_every = 1, nglobal_compress_iters =1)
    
    compressor.decompositions = {k:'tucker2' for k in compressor.decompositions.keys()}

    #while not compressor.done:
    print("\n Compress")
    compressor.compression_step()

    #print("\n Calibrate")
    #compressor.model = calibrate(compressor.compressed_model, device, train_dl,freeze_lnames = lnames_to_compress[:idx])

    compressor.compressed_model = compressor.compressed_model.to(device)

    macs = calc_macs(compressor.compressed_model, (3, data_config["IMG_SIZE"], data_config["IMG_SIZE"]))
    print(f"macs: {macs}")

    #print("\n Test")
    #test(compressor.compressed_model, device, val_dl)

    print('\n Fine-tune')
    
    ranks[lname]=None
    
    print(ranks)

    model_instance.model = compressor.compressed_model

{'0.1.conv': None, '0.2.conv': None, '1.0.conv.0.0': None, '1.0.conv.1.0': None, '1.0.conv.2': None, '2.0.conv.0.0': None, '2.0.conv.1': None, '2.1.conv.0.0': None, '2.1.conv.1': None, '2.2.conv.0.0': None, '2.2.conv.1': None, '2.3.conv.0.0': None, '2.3.conv.1': None, '3.0.conv.0.0': None, '3.0.conv.1.0': None, '3.0.conv.2': None, '3.1.conv.0.0': None, '3.1.conv.1.0': None, '3.1.conv.2': None, '3.2.conv.0.0': None, '3.2.conv.1.0': None, '3.2.conv.2': None, '3.3.conv.0.0': None, '3.3.conv.1.0': None, '3.3.conv.2': None, '3.4.conv.0.0': None, '3.4.conv.1.0': None, '3.4.conv.2': None, '4.0.conv.0': None, '4.0.conv.3': None, '4.0.conv.5.fc1': None, '4.0.conv.5.fc2': None, '4.0.conv.7': None, '4.1.conv.0': None, '4.1.conv.3': None, '4.1.conv.5.fc1': None, '4.1.conv.5.fc2': None, '4.1.conv.7': None, '5.conv': None, '7.conv': None, '0.0.conv': (8, 2)}

 Compress
0.0.conv tucker2
macs: 9823154.0

 Fine-tune
{'0.1.conv': None, '0.2.conv': None, '1.0.conv.0.0': None, '1.0.conv.1.0': None, '1.0.c

TypeError: unsupported operand type(s) for %: 'tuple' and 'int'

In [562]:
for lname in all_layer:
    ranks =  {k:None for k in all_layer}
    ranks[lname] = max_ranks[lname]
    
    if ranks[lname] == None:
        continue

    compressor = CompressorManual(model_instance.model, model_stats,ranks = ranks, ft_every = 1, conv2d_nn_decomposition='tucker2', nglobal_compress_iters = 1)
    
    compressor.decompositions = {k:'tucker2' for k in compressor.decompositions.keys()}


    #while not compressor.done:
    print("\n Compress")
    compressor.compression_step()

    #print("\n Calibrate")
    #compressor.model = calibrate(compressor.compressed_model, device, train_dl,freeze_lnames = lnames_to_compress[:idx])

    compressor.compressed_model = compressor.compressed_model.to(device)

    macs = calc_macs(compressor.compressed_model, (3, data_config["IMG_SIZE"], data_config["IMG_SIZE"]))
    print(f"macs: {macs}")

    #print("\n Test")
    #test(compressor.compressed_model, device, val_dl)

    print('\n Fine-tune')

    model_instance.model = compressor.compressed_model


 Compress
0.conv tucker2
macs: 7827602.0

 Fine-tune

 Compress
1.conv tucker2
macs: 7287954.0

 Fine-tune

 Compress
2.conv tucker2
macs: 7153042.0

 Fine-tune

 Compress
5.conv tucker2
macs: 1034130.0

 Fine-tune

 Compress
7.conv tucker2
macs: 1027880.0

 Fine-tune


## calculate compression model macs

In [277]:
model_instance.model

Sequential(
  (0): Sequential(
    (0): DWConv(
      (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=[1], bias=False)
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
    )
    (1): DWConv(
      (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=[1], groups=16, bias=False)
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
    )
    (2): DWConv(
      (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=[1], groups=16, bias=False)
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
    )
  )
  (1): Sequential(
    (0): InvertedResidualv2(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 

In [467]:
macs = calc_macs(model_instance.model, (3, data_config["IMG_SIZE"], data_config["IMG_SIZE"]))
print(f"macs: {macs}")

macs: 65960.0


## inference compression model

In [81]:
dst='/opt/ml/'
data_config = 'exp/0.6420_150epoch_5800/data.yml'
img_root = 'input/data/test'

In [77]:
#imagefolder function

CLASSES = ['Battery', 'Clothing', 'Glass', 'Metal', 'Paper', 'Paperpack', 'Plastic', 'Plasticbag', 'Styrofoam']

class CustomImageFolder(ImageFolder):
    """ImageFolder with filename."""

    def __getitem__(self, index):
        img_gt = super(CustomImageFolder, self).__getitem__(index)
        fdir = self.imgs[index][0]
        fname = fdir.rsplit(os.path.sep, 1)[-1]
        return (img_gt + (fname,))

In [78]:
#prepare dataloader

def get_dataloader(img_root: str, data_config: str) -> DataLoader:
    """Get dataloader.

    Note:
	Don't forget to set normalization.
    """
    # Load yaml
    data_config = read_yaml(data_config)

    transform_test_args = data_confg["AUG_TEST_PARAMS"] if data_config.get("AUG_TEST_PARAMS") else None
    # Transformation for test
    transform_test = getattr(
        __import__("src.augmentation.policies", fromlist=[""]),
        data_config["AUG_TEST"],
    )(dataset=data_config["DATASET"], img_size=data_config["IMG_SIZE"])

    dataset = CustomImageFolder(root=img_root, transform=transform_test)
    dataloader = DataLoader(
	dataset=dataset,
	batch_size=1,
	num_workers=8
    )
    return dataloader

In [79]:
@torch.no_grad()
def inference(model, dataloader, dst_path: str):
    result = {}
    model = model.to(device)
    model.eval()
    submission_csv = {}
    for img, _, fname in dataloader:
        img = img.to(device)
        pred, enc_data = run_model(model, img)
        pred = torch.argmax(pred)
        submission_csv[fname[0]] = CLASSES[int(pred.detach())]

    result["macs"] = enc_data
    result["submission"] = submission_csv
    j = json.dumps(result, indent=4)
    save_path = os.path.join(dst_path, 'submission.csv')
    with open(save_path, 'w') as outfile:
        json.dump(result, outfile)


In [82]:
# prepare datalaoder
dataloader = get_dataloader(img_root=img_root, data_config=data_config)

# inference
inference(model_instance.model, dataloader, dst)



## fine-tuning compression model

In [32]:
log_dir = os.path.join("exp", datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(log_dir, exist_ok=True)

In [425]:
log_dir='exp/2021-06-14_15-59-59'

In [None]:
data_config['AUG_TRAIN_PARAMS']['n_select'] = 0

In [None]:
train_dl, val_dl, test_dl = create_dataloader(data_config)

In [568]:
train_path = os.path.join(data_config["DATA_PATH"], "train")
model_path = os.path.join(log_dir, "best.pt")

# Create optimizer, scheduler, criterion
optimizer = torch.optim.SGD(
    model_instance.model.parameters(), lr=0.1, momentum=0.9
)


#scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #optimizer=optimizer,
    #max_lr=data_config["INIT_LR"],
    #steps_per_epoch=len(train_dl),
    #epochs=data_config["EPOCHS"],
    #pct_start=0.05,
#)

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,30,50,70,90,110,130,150,170,190], gamma=0.5)

criterion = CustomCriterion(
    samples_per_cls=get_label_counts(train_path)
    if data_config["DATASET"] == "TACO"
    else None,
    device=device,
    #loss_type="weighted"
    loss_type="customloss"
    #loss_type="label_smoothing"
)


# Amp loss scaler
scaler = (
    torch.cuda.amp.GradScaler() if data_config['FP16'] and device != torch.device("cpu") else None
)
#scaler=None

In [483]:
@torch.no_grad()
def test(model,test_dataloader):
    """Test model.

    Args:
        test_dataloader: test data loader module which is a iterator that returns (data, labels)

    Returns:
        loss, f1, accuracy
    """

    #n_batch = _get_n_batch_from_dataloader(test_dataloader)

    running_loss = 0.0
    preds = []
    gt = []
    correct = 0
    total = 0

    num_classes = 9
    label_list = [i for i in range(num_classes)]

    pbar = notebook.tqdm(enumerate(test_dataloader), total=len(test_dataloader))
    model.to(device)
    model.eval()
    for batch, (data, labels) in pbar:
        data, labels = data.to(device), labels.to(device)

        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model(data)
        else:
            outputs = model(data)
        outputs = torch.squeeze(outputs)
        running_loss += criterion(outputs, labels).item()

        _, pred = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()

        preds += pred.to("cpu").tolist()
        gt += labels.to("cpu").tolist()
        pbar.update()
        pbar.set_description(
            f" Val: {'':5} Loss: {(running_loss / (batch + 1)):.3f}, "
            f"Acc: {(correct / total) * 100:.2f}% "
            f"F1(macro): {f1_score(y_true=gt, y_pred=preds, labels=label_list, average='macro', zero_division=0):.2f}"
        )
    loss = running_loss / len(test_dataloader)
    accuracy = correct / total
    f1 = f1_score(
        y_true=gt, y_pred=preds, labels=label_list, average="macro", zero_division=0
    )
    
    return loss, f1, accuracy

In [569]:
#basic training

best_test_acc = -1.0
best_test_f1 = -1.0

num_classes = 9

label_list = [i for i in range(num_classes)]

for epoch in range(n_epoch):
    running_loss, correct, total = 0.0, 0, 0
    preds, gt = [], []
    pbar = notebook.tqdm(enumerate(train_dl), total=len(train_dl))
    model_instance.model.train()
    for batch, (data, labels) in pbar:
        
        data, labels = data.to(device), labels.to(device)

        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model_instance.model(data)
        else:
            outputs = model_instance.model(data)
        outputs = torch.squeeze(outputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()

        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        scheduler.step()

        _, pred = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()
        preds += pred.to("cpu").tolist()
        gt += labels.to("cpu").tolist()

        running_loss += loss.item()
        pbar.update()
        pbar.set_description(
            f"Train: [{epoch + 1:03d}] "
            f"Loss: {(running_loss / (batch + 1)):.3f}, "
            f"Acc: {(correct / total) * 100:.2f}% "
            f"F1(macro): {f1_score(y_true=gt, y_pred=preds, labels=label_list, average='macro', zero_division=0):.2f}"
        )
    pbar.close()

    _, test_f1, test_acc = test(
        model=model_instance.model, test_dataloader=val_dl
    )
    if best_test_f1 > test_f1:
        continue
    best_test_acc = test_acc
    best_test_f1 = test_f1
    print(f"Model saved. Current best test f1: {best_test_f1:.3f}")
    save_model(
        model=model_instance.model,
        path=model_path,
        data=data,
        device=device,
    )


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))


Model saved. Current best test f1: 0.118


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))


Model saved. Current best test f1: 0.118


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))


Model saved. Current best test f1: 0.118


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))


Model saved. Current best test f1: 0.118


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))


Model saved. Current best test f1: 0.119


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))


Model saved. Current best test f1: 0.119


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f1e175efaf0>
Traceback (most recent call last):
  File "/miniconda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1101, in __del__
    self._shutdown_workers()
Exception ignored in:   File "/miniconda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1075, in _shutdown_workers
    

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))

w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)<function _MultiProcessingDataLoaderIter.__del__ at 0x7f1e175efaf0>

Traceback (most recent call last):
  File "/miniconda/lib/python3.8/multiprocessing/process.py", line 147, in join
  File "/miniconda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1101, in __del__
        self._shutdown_workers()assert self._parent_pid == os.getpid(), 'can only join a child process'

  File "/miniconda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1075, in _shutdown_workers
AssertionError    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL): 
can only join a child process  File "/miniconda/lib/python3.8/multiprocessing/process.py", line 147, in join

    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f1e175efaf0><function _MultiProcessingDataLoaderIte




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




KeyboardInterrupt: 

In [514]:
def mixup_data(x, y, alpha=1.0, use_cuda=False):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

In [524]:
#mixup training

best_test_acc = -1.0
best_test_f1 = -1.0

num_classes = 9

label_list = [i for i in range(num_classes)]

for epoch in range(n_epoch):
    running_loss, correct, total = 0.0, 0, 0
    preds, gt = [], []
    pbar = notebook.tqdm(enumerate(train_dl), total=len(train_dl))
    model_instance.model.train()
    for batch, (data, labels) in pbar:
        
        data,label_a,label_b,lam = mixup_data(data,labels,alpha=1.0)
        
        data, labels, label_a, label_b = data.to(device), labels.to(device), label_a.to(device), label_b.to(device)
        

        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model_instance.model(data)
        else:
            outputs = model_instance.model(data)
        outputs = torch.squeeze(outputs)
        #loss = criterion(outputs, labels)
        loss = criterion(outputs, label_a) * lam + criterion(outputs, label_b)* (1-lam)

        optimizer.zero_grad()

        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        scheduler.step()

        _, pred = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()
        preds += pred.to("cpu").tolist()
        gt += labels.to("cpu").tolist()

        running_loss += loss.item()
        pbar.update()
        pbar.set_description(
            f"Train: [{epoch + 1:03d}] "
            f"Loss: {(running_loss / (batch + 1)):.3f}, "
            f"Acc: {(correct / total) * 100:.2f}% "
            f"F1(macro): {f1_score(y_true=gt, y_pred=preds, labels=label_list, average='macro', zero_division=0):.2f}"
        )
    pbar.close()

    _, test_f1, test_acc = test(
        model=model_instance.model, test_dataloader=val_dl
    )
    if best_test_f1 > test_f1:
        continue
    best_test_acc = test_acc
    best_test_f1 = test_f1
    print(f"Model saved. Current best test f1: {best_test_f1:.3f}")
    save_model(
        model=model_instance.model,
        path=model_path,
        data=data,
        device=device,
    )

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))


Model saved. Current best test f1: 0.104


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))


Model saved. Current best test f1: 0.119


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))


Model saved. Current best test f1: 0.129


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))


Model saved. Current best test f1: 0.149


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=170.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=509.0), HTML(value='')))




KeyboardInterrupt: 