## Load libraries

In [None]:
#python

import json
import copy
import random
import yaml

from datetime import datetime
from typing import Any, Dict, Tuple, Union, List

from collections import defaultdict
from tqdm import tqdm,notebook

#automl optuna
import optuna

#sklearn

from sklearn.metrics import f1_score

#numpy

import numpy as np

#pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

#baseline

import warnings

warnings.filterwarnings('ignore')

import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('/opt/ml/code/src/'))))

import src

from src.dataloader import create_dataloader
from src.loss import CustomCriterion
from src.model import Model
from src.trainer import TorchTrainer
from src.utils.common import get_label_counts, read_yaml
from src.utils.macs import calc_macs
from src.utils.torch_utils import check_runtime, model_info, save_model
from src.augmentation.policies import simple_augment_test
from src.utils.inference_utils import run_model


from train import train

#musco

from musco.pytorch import CompressorVBMF, CompressorPR, CompressorManual
from flopco import FlopCo
from musco.pytorch.compressor.rank_selection.estimator import estimate_rank_for_compression_rate, estimate_vbmf_ranks



## Select device

In [None]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Fix random seeds

In [None]:
# Seed every RNG the notebook touches (torch, numpy, python) from one constant.
SEED = 30

#torch seed
torch.manual_seed(SEED)
# manual_seed_all covers every visible GPU, not only the current one; it is a no-op without CUDA.
torch.cuda.manual_seed_all(SEED)

#numpy seed
np.random.seed(SEED)

#python seed
random.seed(SEED)

## load base model

In [None]:
# Parsed model / data configs of the pretrained experiment. ``read_yaml(cfg=<path>)``
# already yields the parsed (dict-like) config, so it is not fed back through read_yaml.
exp_dir = "exp/0.5177_100epoch_1120"
model_config = read_yaml(cfg=f"{exp_dir}/model.yml")
data_config = read_yaml(cfg=f"{exp_dir}/data.yml")

In [None]:
# Build the network described by model.yml; verbose=True prints the layer summary shown below.
model_instance = Model(model_config,verbose=True)

idx |   n |     params |          module |            arguments |   in_channel |   out_channel
----------------------------------------------------------------------------------------------
  0 |   3 |        816 |          DWConv | [16, 3, 2, None, 'ReLU'] |            3           16
  1 |   1 |      2,016 | InvertedResidualv2 |           [32, 2, 2] |           16           32
  2 |   4 |      2,288 | InvertedResidualv2 |           [16, 1, 2] |           32           16
  3 |   5 |      7,360 | InvertedResidualv2 |           [16, 2, 2] |           16           16
  4 |   2 |    240,656 | InvertedResidualv3 | [5, 3.5, 128, 1, 1, 2] |           16          128
  5 |   1 |     83,200 |            Conv |          [640, 1, 1] |          128          640
  6 |   1 |          0 |   GlobalAvgPool |                   [] |          640          640
  7 |   1 |      5,778 |       FixedConv | [9, 1, 1, None, 1, None] |          640            9
Model Summary: 161 layers, 342,114 parameters, 342,1

In [None]:
# Pretrained weights of the base model; compression starts from these.
model_path = 'exp/0.5177_100epoch_1120/best.pt'

if os.path.isfile(model_path):
    model_instance.model.load_state_dict(torch.load(model_path, map_location=device))
else:
    # warnings.filterwarnings('ignore') is active for the whole notebook, so use print:
    # otherwise a missing checkpoint would silently leave the weights randomly initialised.
    print(f"WARNING: {model_path} not found; using randomly initialised weights.")

In [None]:
# MAC count of the uncompressed model for one 3 x IMG_SIZE x IMG_SIZE input.
# tune_two hands this to objective_two, which prunes candidates that exceed it.
original_macs = calc_macs(
    model_instance.model,
    (3,) + (data_config["IMG_SIZE"],) * 2,
)
print(f"macs: {original_macs}")

macs: 11242418.0


## Register per-layer rank buffers

In [None]:
# Give every Conv2d a default ``rank`` buffer holding two fractions (0.5, 0.5).
# ``compression`` multiplies fraction i by ``weight.shape[i]`` to obtain absolute ranks.
for module_name, module in model_instance.model.named_modules():
    if not isinstance(module, nn.Conv2d):
        continue
    module.register_buffer('rank', torch.tensor([0.5, 0.5]))

## Calculate model statistics

In [None]:
# Put the model on the target device in eval mode, then profile it with FlopCo for an
# input of shape (1, 3, IMG_SIZE, IMG_SIZE). The cells below read its per-layer
# ``flops`` and ``ltypes``.
model_instance.model = model_instance.model.to(device).eval()
model_stats = FlopCo(
    model_instance.model,
    img_size=(1, 3, data_config['IMG_SIZE'], data_config['IMG_SIZE']),
    device=device,
)

In [None]:
# Display FlopCo's overall FLOP count and its per-layer relative FLOPs.
model_stats.total_flops,  model_stats.relative_flops

## Find the conv layers to compress

In [None]:
# Every layer name FlopCo profiled; used as the key set of the per-layer rank dicts.
all_layer = list(model_stats.flops.keys())
all_layer

In [None]:
# Only dense (groups == 1) Conv2d layers are Tucker-2 decomposed; depthwise/grouped
# convolutions and every other layer type are left as they are.
lnames_to_compress = [
    layer_name
    for layer_name in model_stats.flops.keys()
    if model_stats.ltypes[layer_name]['type'] == nn.Conv2d
    and model_stats.ltypes[layer_name]['groups'] == 1
]
lnames_to_compress

## define compression function

In [None]:
def _absolute_ranks(model, lname):
    """Return the two absolute Tucker-2 ranks for module ``lname`` of ``model``, or None to skip it."""
    for name, module in model.named_modules():
        if name != lname:
            continue
        if module.groups != 1:
            return None
        # Fraction i of the ``rank`` buffer scales weight.shape[i], i.e. (out_channels,
        # in_channels) for a Conv2d weight. Each rank is truncated and kept at >= 2.
        return [max(int(r * module.weight.shape[i]), 2) for i, r in enumerate(module.rank)]
    return None


def compression(lnames_to_compress, device, model_instance):
    """Tucker-2 decompose the listed layers of ``model_instance.model``, one layer at a time.

    For each name in ``lnames_to_compress``, the matching ``groups == 1`` module's ``rank``
    buffer (two fractions registered beforehand) is turned into absolute ranks. A
    ``CompressorManual`` that decomposes only that layer is then run to completion.
    Names that are not found, or whose module has ``groups != 1``, are skipped.

    Relies on the notebook globals ``all_layer`` and ``model_stats``, the FlopCo
    statistics computed on the uncompressed model.

    Args:
        lnames_to_compress: module names to decompose, processed in order.
        device: the compressed model is moved here after every compression step.
        model_instance: ``Model`` wrapper whose ``.model`` attribute is replaced in place.

    Returns:
        The compressed ``nn.Module``, which is also stored in ``model_instance.model``.
    """
    for lname in lnames_to_compress:
        layer_ranks = _absolute_ranks(model_instance.model, lname)
        if layer_ranks is None:
            continue

        # Only ``lname`` gets a rank, so this compressor decomposes just that layer.
        ranks = {k: None for k in all_layer}
        ranks[lname] = layer_ranks

        compressor = CompressorManual(
            model_instance.model,
            model_stats,
            ranks=ranks,
            ft_every=1,
            conv2d_nn_decomposition='tucker2',
            nglobal_compress_iters=1,
        )
        # Force Tucker-2 for every layer the compressor tracks.
        compressor.decompositions = {k: 'tucker2' for k in compressor.decompositions.keys()}

        while not compressor.done:
            compressor.compression_step()
            compressor.compressed_model = compressor.compressed_model.to(device)

        model_instance.model = compressor.compressed_model

    return model_instance.model

In [None]:
@torch.no_grad()
def _evaluate(model, dataloader, criterion, scaler, device, label_list):
    """Return ``(loss, macro_f1, accuracy)`` of ``model`` on ``dataloader``.

    Computes the same metrics as the notebook's ``test`` function. The criterion, AMP
    scaler and device are passed in explicitly, because in notebook order the global
    ``criterion``/``scaler`` are only created in the fine-tuning section, after the
    tuning cells (which call ``train`` through ``objective_two``) have run.
    """
    running_loss, correct, total = 0.0, 0, 0
    preds, gt = [], []
    model.to(device)
    model.eval()
    pbar = notebook.tqdm(enumerate(dataloader), total=len(dataloader))
    for batch, (data, labels) in pbar:
        data, labels = data.to(device), labels.to(device)
        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model(data)
        else:
            outputs = model(data)
        outputs = torch.flatten(outputs, 1)
        running_loss += criterion(outputs, labels).item()

        _, pred = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()
        preds += pred.to("cpu").tolist()
        gt += labels.to("cpu").tolist()
        pbar.set_description(
            f" Val: {'':5} Loss: {(running_loss / (batch + 1)):.3f}, "
            f"Acc: {(correct / total) * 100:.2f}% "
            f"F1(macro): {f1_score(y_true=gt, y_pred=preds, labels=label_list, average='macro', zero_division=0):.2f}"
        )
    pbar.close()

    n_batches = len(dataloader)
    loss = running_loss / n_batches if n_batches else 0.0
    accuracy = correct / total if total else 0.0
    f1 = f1_score(
        y_true=gt, y_pred=preds, labels=label_list, average="macro", zero_division=0
    )
    return loss, f1, accuracy


def train(model_instance, model_path, optimizer, scheduler, criterion, scaler, train_dl, val_dl, device,
          n_epoch=30, num_classes=9):
    """Train a model and checkpoint it whenever validation macro F1 does not decrease.

    Args:
        model_instance: the ``nn.Module`` to train. Despite the name this is a bare module,
            not a ``Model`` wrapper; ``objective_two`` passes the compressed network.
        model_path: path handed to ``save_model`` on every new best (or tied) F1.
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler. It is stepped once per *batch*, so it must be configured
            in batch units (e.g. ``OneCycleLR`` with ``steps_per_epoch``).
        criterion: loss, called as ``criterion(outputs, labels)``.
        scaler: ``torch.cuda.amp.GradScaler`` to enable mixed precision, or ``None``.
        train_dl: training loader yielding ``(data, labels)`` batches.
        val_dl: validation loader yielding ``(data, labels)`` batches.
        device: device the model and batches are moved to.
        n_epoch: number of epochs. The default of 30 matches the scheduler built in
            ``objective_two``.
        num_classes: number of class labels used for macro F1.

    Returns:
        ``(best_test_acc, best_test_f1)`` on the validation set, or ``(-1.0, -1.0)`` if
        no epoch ran.
    """
    best_test_acc = -1.0
    best_test_f1 = -1.0
    label_list = list(range(num_classes))

    for epoch in range(n_epoch):
        running_loss, correct, total = 0.0, 0, 0
        preds, gt = [], []
        pbar = notebook.tqdm(enumerate(train_dl), total=len(train_dl))
        model_instance.train()
        for batch, (data, labels) in pbar:
            data, labels = data.to(device), labels.to(device)

            if scaler:
                with torch.cuda.amp.autocast():
                    outputs = model_instance(data)
            else:
                outputs = model_instance(data)
            # flatten(1) rather than squeeze(): squeeze also removes the batch dimension
            # of a final batch holding a single sample.
            outputs = torch.flatten(outputs, 1)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            if scaler:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            scheduler.step()

            _, pred = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            preds += pred.to("cpu").tolist()
            gt += labels.to("cpu").tolist()

            running_loss += loss.item()
            # No manual pbar.update(): tqdm already advances while being iterated.
            pbar.set_description(
                f"Train: [{epoch + 1:03d}] "
                f"Loss: {(running_loss / (batch + 1)):.3f}, "
                f"Acc: {(correct / total) * 100:.2f}% "
                f"F1(macro): {f1_score(y_true=gt, y_pred=preds, labels=label_list, average='macro', zero_division=0):.2f}"
            )
        pbar.close()

        _, test_f1, test_acc = _evaluate(
            model_instance, val_dl, criterion, scaler, device, label_list
        )
        if best_test_f1 > test_f1:
            continue
        best_test_acc = test_acc
        best_test_f1 = test_f1
        print(f"Model saved. Current best test f1: {best_test_f1:.3f}")
        save_model(
            model=model_instance,
            path=model_path,
            data=data,
            device=device,
        )

    return best_test_acc, best_test_f1

In [None]:
@torch.no_grad()
def test(model,test_dataloader):
    """Evaluate ``model`` on ``test_dataloader``.

    Reads the notebook globals ``device``, ``scaler`` (when truthy, the forward pass runs
    under ``torch.cuda.amp.autocast``) and ``criterion``. They must be defined first.

    Args:
        model: ``nn.Module`` to evaluate; it is moved to ``device`` and left in eval mode.
        test_dataloader: loader yielding ``(data, labels)`` batches and supporting ``len()``.

    Returns:
        ``(loss, f1, accuracy)``: mean per-batch loss, macro F1 over the 9 class labels,
        and accuracy. Loss and accuracy are 0.0 when the loader yields nothing.
    """
    num_classes = 9
    label_list = list(range(num_classes))

    running_loss = 0.0
    preds, gt = [], []
    correct = 0
    total = 0

    pbar = notebook.tqdm(enumerate(test_dataloader), total=len(test_dataloader))
    model.to(device)
    model.eval()
    for batch, (data, labels) in pbar:
        data, labels = data.to(device), labels.to(device)

        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model(data)
        else:
            outputs = model(data)
        # flatten(1) rather than squeeze(): squeeze also removes the batch dimension of a
        # single-sample batch, which breaks the criterion and torch.max(..., 1).
        outputs = torch.flatten(outputs, 1)
        running_loss += criterion(outputs, labels).item()

        _, pred = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()

        preds += pred.to("cpu").tolist()
        gt += labels.to("cpu").tolist()
        # No manual pbar.update(): tqdm already advances while being iterated.
        pbar.set_description(
            f" Val: {'':5} Loss: {(running_loss / (batch + 1)):.3f}, "
            f"Acc: {(correct / total) * 100:.2f}% "
            f"F1(macro): {f1_score(y_true=gt, y_pred=preds, labels=label_list, average='macro', zero_division=0):.2f}"
        )
    pbar.close()

    n_batches = len(test_dataloader)
    loss = running_loss / n_batches if n_batches else 0.0
    accuracy = correct / total if total else 0.0
    f1 = f1_score(
        y_true=gt, y_pred=preds, labels=label_list, average="macro", zero_division=0
    )

    return loss, f1, accuracy

## Define objective functions

In [None]:
def objective_one(trial, device, model_config, data_config, model_path):
    """Optuna objective for ``tune_one``: return the MACs of a model compressed with sampled ranks.

    A fresh model is built (with pretrained weights when ``model_path`` exists). Each Conv2d
    gets two rank fractions sampled between 0 and 1, named ``<layer>_one`` /
    ``<layer>_two``, and stored in a ``rank`` buffer. The layers in the notebook global
    ``lnames_to_compress`` are then decomposed with ``compression``.

    Args:
        trial: Optuna trial supplying the rank fractions.
        device: torch device used for loading and compression.
        model_config: parsed model YAML config.
        data_config: parsed data YAML config; ``IMG_SIZE`` is read.
        model_path: pretrained weights to load when the file exists.

    Returns:
        MAC count of the compressed model; ``tune_one`` minimises it.
    """
    model_instance = Model(model_config, verbose=False)

    if os.path.isfile(model_path):
        model_instance.model.load_state_dict(torch.load(model_path, map_location=device))

    for name, module in model_instance.model.named_modules():
        if isinstance(module, nn.Conv2d):
            rank_one = trial.suggest_float(name + '_one', 0.0, 1.0)
            rank_two = trial.suggest_float(name + '_two', 0.0, 1.0)
            module.register_buffer('rank', torch.tensor([rank_one, rank_two]))

    compression_model = compression(lnames_to_compress, device, model_instance)

    return calc_macs(compression_model, (3, data_config["IMG_SIZE"], data_config["IMG_SIZE"]))

In [None]:
def objective_two(trial, device, model_config, data_config, model_path, original_macs):
    """Optuna objective for ``tune_two``: compress with sampled ranks, fine-tune, and score.

    Args:
        trial: Optuna trial supplying a ``<layer>_one`` / ``<layer>_two`` rank fraction,
            between 0 and 1, for every Conv2d.
        device: torch device used for compression and training.
        model_config: parsed model YAML config.
        data_config: parsed data YAML config; IMG_SIZE, DATA_PATH, DATASET, FP16 and
            INIT_LR are read.
        model_path: pretrained weights loaded before compression when the file exists.
        original_macs: MACs of the uncompressed model; larger candidates are pruned.

    Returns:
        ``(best validation macro F1, MACs)`` of the compressed model.

    Raises:
        optuna.TrialPruned: if the compressed model has more MACs than ``original_macs``.

    Also reads the notebook global ``lnames_to_compress``, and through ``compression``,
    ``all_layer`` / ``model_stats``.
    """
    model_instance = Model(model_config, verbose=False)

    if os.path.isfile(model_path):
        model_instance.model.load_state_dict(torch.load(model_path, map_location=device))

    # Sample two rank fractions per Conv2d and stash them in a ``rank`` buffer,
    # which ``compression`` reads back.
    for name, module in model_instance.model.named_modules():
        if isinstance(module, nn.Conv2d):
            rank_one = trial.suggest_float(name + '_one', 0.0, 1.0)
            rank_two = trial.suggest_float(name + '_two', 0.0, 1.0)
            module.register_buffer('rank', torch.tensor([rank_one, rank_two]))

    # ``compression`` only moves the model to ``device`` when it decomposes a layer,
    # so move it explicitly in case nothing was compressed.
    compression_model = compression(lnames_to_compress, device, model_instance).to(device)

    macs = calc_macs(compression_model, (3, data_config["IMG_SIZE"], data_config["IMG_SIZE"]))
    print(f"macs: {macs}")

    # Prune before building data loaders and training, so rejected trials stay cheap.
    if macs > original_macs:
        print(f' trial: {trial.number}, This model has very large macs:{macs}')
        raise optuna.TrialPruned()

    train_dl, val_dl, _ = create_dataloader(data_config)
    train_path = os.path.join(data_config["DATA_PATH"], "train")

    # One checkpoint directory per trial. The notebook-level ``log_dir`` is not defined
    # until the fine-tuning section, and a shared path would let each trial overwrite
    # the previous trial's best checkpoint.
    trial_dir = os.path.join("exp", trial.study.study_name, f"trial_{trial.number:04d}")
    os.makedirs(trial_dir, exist_ok=True)
    save_path = os.path.join(trial_dir, "best.pt")

    criterion = CustomCriterion(
        samples_per_cls=get_label_counts(train_path)
        if data_config["DATASET"] == "TACO"
        else None,
        device=device,
        #loss_type="weighted"
        #loss_type="customloss"
        #loss_type="label_smoothing"
    )

    # Mixed precision only when FP16 is enabled and we are not on the CPU.
    scaler = (
        torch.cuda.amp.GradScaler() if data_config['FP16'] and device != torch.device("cpu") else None
    )

    # Must equal the number of epochs ``train`` runs (30). OneCycleLR is stepped per batch
    # and raises if stepped past epochs * steps_per_epoch.
    n_epoch = 30
    optimizer = torch.optim.SGD(
        compression_model.parameters(), lr=0.1, momentum=0.9
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=data_config["INIT_LR"],
        steps_per_epoch=len(train_dl),
        epochs=n_epoch,
        pct_start=0.05,
    )

    _, best_f1 = train(
        compression_model,
        save_path,
        optimizer,
        scheduler,
        criterion,
        scaler,
        train_dl,
        val_dl,
        device,
    )

    return best_f1, macs

In [None]:
def tune_one(device, model_config, data_config, model_path, study_name= "pstage_automl", n_trials=15000):
    """Run the single-objective rank search driven by ``objective_one``.

    ``objective_one`` returns only the compressed model's MAC count, with no accuracy
    term, so this study searches for the smallest model.

    Args:
        device: forwarded to ``objective_one``.
        model_config: forwarded to ``objective_one``.
        data_config: forwarded to ``objective_one``.
        model_path: forwarded to ``objective_one``.
        study_name: Optuna study name.
        n_trials: number of trials to run (default 15000).

    Returns:
        The ``optuna.Study`` after optimisation.
    """
    sampler = optuna.samplers.TPESampler(n_startup_trials=20)

    study = optuna.create_study(
        direction="minimize",
        study_name=study_name,
        sampler=sampler,
        load_if_exists=True
    )

    study.optimize(
        lambda trial: objective_one(trial, device, model_config, data_config, model_path),
        n_trials=n_trials,
    )

    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    print("Best trials:")
    for tr in study.best_trials:
        print(f"  value:{tr.values}")
        for key, value in tr.params.items():
            print(f"    {key}:{value}")

    return study

In [None]:
def tune_two(device, model_config, data_config, model_path, study_name= "pstage_automl", n_trials=100, macs_limit=None):
    """Run the two-objective search (maximise validation F1, minimise MACs) via ``objective_two``.

    Args:
        device: forwarded to ``objective_two``.
        model_config: forwarded to ``objective_two``.
        data_config: forwarded to ``objective_two``.
        model_path: forwarded to ``objective_two``.
        study_name: Optuna study name. NOTE(review): the default is the same as
            ``tune_one``'s single-objective study; the notebook passes "pstage_automl2".
        n_trials: number of trials to run (default 100).
        macs_limit: MACs ceiling forwarded as ``original_macs``. ``None`` uses the
            notebook-level ``original_macs``.

    Returns:
        The multi-objective ``optuna.Study``; ``study.best_trials`` is its Pareto front.
    """
    limit = original_macs if macs_limit is None else macs_limit

    sampler = optuna.samplers.MOTPESampler(n_startup_trials=20)

    study = optuna.create_study(
        directions=["maximize", "minimize"],
        study_name=study_name,
        sampler=sampler,
        load_if_exists=True
    )

    study.optimize(
        lambda trial: objective_two(trial, device, model_config, data_config, model_path, limit),
        n_trials=n_trials,
    )

    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    print("Best trials:")
    # Pareto-optimal trials: values are (F1, MACs).
    for tr in study.best_trials:
        print(f"  value1:{tr.values[0]}, value2:{tr.values[1]}")
        for key, value in tr.params.items():
            print(f"    {key}:{value}")

    return study

In [None]:
# Stage 1: single-objective search that minimises MACs.
study_name = "pstage_automl"
study = tune_one(
    device=device,
    model_config=model_config,
    data_config=data_config,
    model_path=model_path,
    study_name=study_name,
)

In [None]:
# Stage 2: two-objective search (validation F1 vs. MACs) under its own study name.
study_name = "pstage_automl2"
study = tune_two(
    device=device,
    model_config=model_config,
    data_config=data_config,
    model_path=model_path,
    study_name=study_name,
)

## Best parameters

In [None]:
# Rank fractions copied from a finished tuning run, keyed '<layer>_one' / '<layer>_two'.
# NOTE(review): presumably one Pareto-optimal trial of tune_two -- confirm which.
# Only the keys of layers in lnames_to_compress are read by the fine-tuning cells.
best_params = {'0.0.conv_one': 0.06249513139317668, '0.0.conv_two': 0.7406647857391946, '0.1.conv_one': 0.6363534088192819, '0.1.conv_two': 0.9142387393539608, '0.2.conv_one': 0.1292877468619793, '0.2.conv_two': 0.7383343969092988, '1.0.conv.0.0_one': 0.0008341748802290622, '1.0.conv.0.0_two': 0.05002966054540159, '1.0.conv.1.0_one': 0.8914359902260529, '1.0.conv.1.0_two': 0.002657766591352799, '1.0.conv.2_one': 0.06472996627499569, '1.0.conv.2_two': 0.0018459508866916843, '2.0.conv.0.0_one': 0.9115462763127067, '2.0.conv.0.0_two': 0.6639459793998597, '2.0.conv.1_one': 0.0853569870993198, '2.0.conv.1_two': 0.18126938180383792, '2.1.conv.0.0_one': 0.9817793298392701, '2.1.conv.0.0_two': 0.9001258631290798, '2.1.conv.1_one': 0.281183024423954, '2.1.conv.1_two': 0.8882239594838728, '2.2.conv.0.0_one': 0.17206134527733363, '2.2.conv.0.0_two': 0.8500281545934149, '2.2.conv.1_one': 0.14476380947875786, '2.2.conv.1_two': 0.005596879308822559, '2.3.conv.0.0_one': 0.44316357904486675, '2.3.conv.0.0_two': 0.5941322113050731, '2.3.conv.1_one': 0.572635098239233, '2.3.conv.1_two': 0.8468397676868743, '3.0.conv.0.0_one': 0.21862166138526337, '3.0.conv.0.0_two': 0.07024662744792722, '3.0.conv.1.0_one': 0.04545779266265461, '3.0.conv.1.0_two': 0.9045148663886942, '3.0.conv.2_one': 0.29411044633051125, '3.0.conv.2_two': 0.5847107548343258, '3.1.conv.0.0_one': 0.40518093356151513, '3.1.conv.0.0_two': 0.7475443817808775, '3.1.conv.1.0_one': 0.8043939144749847, '3.1.conv.1.0_two': 0.289022029677154, '3.1.conv.2_one': 0.0014290599572399554, '3.1.conv.2_two': 0.6629448092569628, '3.2.conv.0.0_one': 0.3641895413721229, '3.2.conv.0.0_two': 0.013451955264283952, '3.2.conv.1.0_one': 0.5363319490599914, '3.2.conv.1.0_two': 0.4116140987993332, '3.2.conv.2_one': 0.4913039777560289, '3.2.conv.2_two': 0.8455161001109379, '3.3.conv.0.0_one': 0.08959084143598055, '3.3.conv.0.0_two': 0.8011990428807316, '3.3.conv.1.0_one': 0.9341552726271395, '3.3.conv.1.0_two': 0.9627100873893634, 
'3.3.conv.2_one': 0.10003841299442714, '3.3.conv.2_two': 0.46928142816001717, '3.4.conv.0.0_one': 0.6525725884055026, '3.4.conv.0.0_two': 0.052133724856946936, '3.4.conv.1.0_one': 0.19177841739710982, '3.4.conv.1.0_two': 0.9318086669838463, '3.4.conv.2_one': 0.41486284037709126, '3.4.conv.2_two': 0.7832709259249746, '4.0.conv.0_one': 0.004875134710417898, '4.0.conv.0_two': 0.3262479183020909, '4.0.conv.3_one': 0.0012859373107223585, '4.0.conv.3_two': 0.19911844284242933, '4.0.conv.5.fc1_one': 0.32973034456086264, '4.0.conv.5.fc1_two': 0.30977965845059585, '4.0.conv.5.fc2_one': 0.16132036957829732, '4.0.conv.5.fc2_two': 0.9999374936964174, '4.0.conv.7_one': 0.09044786189188206, '4.0.conv.7_two': 0.10948415544559788, '4.1.conv.0_one': 2.315542203615295e-05, '4.1.conv.0_two': 0.08849056846575871, '4.1.conv.3_one': 0.769029022274578, '4.1.conv.3_two': 0.4263477411613994, '4.1.conv.5.fc1_one': 0.42351845543985933, '4.1.conv.5.fc1_two': 0.00549091928708752, '4.1.conv.5.fc2_one': 0.08237916516875739, '4.1.conv.5.fc2_two': 0.2120508770543559, '4.1.conv.7_one': 0.47650604499479804, '4.1.conv.7_two': 0.0006197777375650953, '5.conv_one': 0.00010686396333617018, '5.conv_two': 0.2755268916624911, '7.conv_one': 0.17916232010837874, '7.conv_two': 0.12058862027715131}

In [None]:
# ``study`` holds tune_two's two-objective study at this point, and ``best_trial`` raises
# for multi-objective studies. Show the parameters of every Pareto-optimal trial instead.
[t.params for t in study.best_trials]

## Fine-tuning

In [None]:
#load model
# Parsed model / data configs of the pretrained experiment. ``read_yaml(cfg=<path>)``
# already yields the parsed (dict-like) config, so it is not fed back through read_yaml.
exp_dir = "exp/0.5177_100epoch_1120"
model_config = read_yaml(cfg=f"{exp_dir}/model.yml")
data_config = read_yaml(cfg=f"{exp_dir}/data.yml")

In [None]:
# Fresh copy of the base network for fine-tuning; verbose=True prints the layer summary below.
model_instance = Model(model_config,verbose=True)

idx |   n |     params |          module |            arguments |   in_channel |   out_channel
----------------------------------------------------------------------------------------------
  0 |   3 |        816 |          DWConv | [16, 3, 2, None, 'ReLU'] |            3           16
  1 |   1 |      2,016 | InvertedResidualv2 |           [32, 2, 2] |           16           32
  2 |   4 |      2,288 | InvertedResidualv2 |           [16, 1, 2] |           32           16
  3 |   5 |      7,360 | InvertedResidualv2 |           [16, 2, 2] |           16           16
  4 |   2 |    240,656 | InvertedResidualv3 | [5, 3.5, 128, 1, 1, 2] |           16          128
  5 |   1 |     83,200 |            Conv |          [640, 1, 1] |          128          640
  6 |   1 |          0 |   GlobalAvgPool |                   [] |          640          640
  7 |   1 |      5,778 |       FixedConv | [9, 1, 1, None, 1, None] |          640            9
Model Summary: 161 layers, 342,114 parameters, 342,1

In [None]:
# Pretrained weights of the base model; the tuned compression is applied on top of these.
model_path = 'exp/0.5177_100epoch_1120/best.pt'

if os.path.isfile(model_path):
    model_instance.model.load_state_dict(torch.load(model_path, map_location=device))
else:
    # warnings.filterwarnings('ignore') is active for the whole notebook, so use print:
    # otherwise a missing checkpoint would silently leave the weights randomly initialised.
    print(f"WARNING: {model_path} not found; using randomly initialised weights.")

In [None]:
# Display the names of the layers that will be decomposed.
lnames_to_compress

In [None]:
#refine rank dictionary
# Map each layer to compress onto its tuned [_one, _two] rank-fraction pair.
rank = {
    lname: [best_params[lname + '_one'], best_params[lname + '_two']]
    for lname in lnames_to_compress
}

In [None]:
# Display the tuned rank fractions per layer.
rank

In [None]:
#create rank buffer
# Attach the tuned fraction pair as a ``rank`` buffer on each Conv2d that will be
# compressed; ``compression`` multiplies fraction i by ``weight.shape[i]``.
for module_name, module in model_instance.model.named_modules():
    if isinstance(module, nn.Conv2d) and module_name in lnames_to_compress:
        module.register_buffer('rank', torch.tensor(rank[module_name]))

In [None]:
#compression
# Tucker-2 decompose every layer in ``lnames_to_compress`` using the tuned ``rank``
# buffers. This is the same per-layer loop that ``compression`` (defined in the tuning
# section) implements, so call it instead of keeping a second copy. That also stops the
# loop's temporaries from overwriting notebook globals such as the tuned ``rank`` dict.
model_instance.model = compression(lnames_to_compress, device, model_instance)

In [None]:
#calculate macs
# MAC count of the compressed model for one 3 x IMG_SIZE x IMG_SIZE input.
macs = calc_macs(
    model_instance.model,
    (3,) + (data_config["IMG_SIZE"],) * 2,
)
print(f"macs: {macs}")

In [None]:
# One fresh run directory per fine-tuning session, named after the current local time.
log_dir = os.path.join("exp", f"{datetime.now():%Y-%m-%d_%H-%M-%S}")
os.makedirs(log_dir, exist_ok=True)

In [None]:
# Build the train / validation / test loaders described by data.yml.
train_dl, val_dl, test_dl = create_dataloader(data_config)

In [None]:
#hyperparameter
# Optimizer, LR schedule, loss and AMP scaler for fine-tuning the compressed model.

train_path = os.path.join(data_config["DATA_PATH"], "train")
model_path = os.path.join(log_dir, "best.pt")

# Create optimizer, scheduler, criterion
optimizer = torch.optim.SGD(
    model_instance.model.parameters(), lr=0.1, momentum=0.9
)

# The training loop below calls scheduler.step() once per *batch*, so the epoch
# milestones are converted to batch counts. Passed as raw epoch numbers, all ten LR
# halvings would happen within the first 190 batches.
milestone_epochs = [10, 30, 50, 70, 90, 110, 130, 150, 170, 190]
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[epoch * len(train_dl) for epoch in milestone_epochs],
    gamma=0.5,
)

criterion = CustomCriterion(
    samples_per_cls=get_label_counts(train_path)
    if data_config["DATASET"] == "TACO"
    else None,
    device=device,
    #loss_type="weighted"
    #loss_type="customloss"
    #loss_type="label_smoothing"
)


# Amp loss scaler: mixed precision only when FP16 is enabled and we are not on the CPU.
scaler = (
    torch.cuda.amp.GradScaler() if data_config['FP16'] and device != torch.device("cpu") else None
)

In [None]:
#test function

@torch.no_grad()
def test(model,test_dataloader):
    """Evaluate ``model`` on ``test_dataloader``.

    Reads the notebook globals ``device``, ``scaler`` (when truthy, the forward pass runs
    under ``torch.cuda.amp.autocast``) and ``criterion``. They must be defined first.

    Args:
        model: ``nn.Module`` to evaluate; it is moved to ``device`` and left in eval mode.
        test_dataloader: loader yielding ``(data, labels)`` batches and supporting ``len()``.

    Returns:
        ``(loss, f1, accuracy)``: mean per-batch loss, macro F1 over the 9 class labels,
        and accuracy. Loss and accuracy are 0.0 when the loader yields nothing.
    """
    num_classes = 9
    label_list = list(range(num_classes))

    running_loss = 0.0
    preds, gt = [], []
    correct = 0
    total = 0

    pbar = notebook.tqdm(enumerate(test_dataloader), total=len(test_dataloader))
    model.to(device)
    model.eval()
    for batch, (data, labels) in pbar:
        data, labels = data.to(device), labels.to(device)

        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model(data)
        else:
            outputs = model(data)
        # flatten(1) rather than squeeze(): squeeze also removes the batch dimension of a
        # single-sample batch, which breaks the criterion and torch.max(..., 1).
        outputs = torch.flatten(outputs, 1)
        running_loss += criterion(outputs, labels).item()

        _, pred = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()

        preds += pred.to("cpu").tolist()
        gt += labels.to("cpu").tolist()
        # No manual pbar.update(): tqdm already advances while being iterated.
        pbar.set_description(
            f" Val: {'':5} Loss: {(running_loss / (batch + 1)):.3f}, "
            f"Acc: {(correct / total) * 100:.2f}% "
            f"F1(macro): {f1_score(y_true=gt, y_pred=preds, labels=label_list, average='macro', zero_division=0):.2f}"
        )
    pbar.close()

    n_batches = len(test_dataloader)
    loss = running_loss / n_batches if n_batches else 0.0
    accuracy = correct / total if total else 0.0
    f1 = f1_score(
        y_true=gt, y_pred=preds, labels=label_list, average="macro", zero_division=0
    )

    return loss, f1, accuracy

In [None]:
#basic training
# Fine-tune the compressed model, checkpointing to ``model_path`` whenever validation
# macro F1 does not drop below the best seen so far.
n_epoch = 200

best_test_acc = -1.0
best_test_f1 = -1.0

num_classes = 9

label_list = list(range(num_classes))

for epoch in range(n_epoch):
    running_loss, correct, total = 0.0, 0, 0
    preds, gt = [], []
    pbar = notebook.tqdm(enumerate(train_dl), total=len(train_dl))
    model_instance.model.train()
    for batch, (data, labels) in pbar:

        data, labels = data.to(device), labels.to(device)

        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model_instance.model(data)
        else:
            outputs = model_instance.model(data)
        # flatten(1) rather than squeeze(): squeeze also removes the batch dimension of a
        # final batch holding a single sample.
        outputs = torch.flatten(outputs, 1)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()

        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        # Stepped per batch; the scheduler cell sizes its milestones accordingly.
        scheduler.step()

        _, pred = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()
        preds += pred.to("cpu").tolist()
        gt += labels.to("cpu").tolist()

        running_loss += loss.item()
        # No manual pbar.update(): tqdm already advances while being iterated.
        pbar.set_description(
            f"Train: [{epoch + 1:03d}] "
            f"Loss: {(running_loss / (batch + 1)):.3f}, "
            f"Acc: {(correct / total) * 100:.2f}% "
            f"F1(macro): {f1_score(y_true=gt, y_pred=preds, labels=label_list, average='macro', zero_division=0):.2f}"
        )
    pbar.close()

    _, test_f1, test_acc = test(
        model=model_instance.model, test_dataloader=val_dl
    )
    if best_test_f1 > test_f1:
        continue
    best_test_acc = test_acc
    best_test_f1 = test_f1
    print(f"Model saved. Current best test f1: {best_test_f1:.3f}")
    save_model(
        model=model_instance.model,
        path=model_path,
        data=data,
        device=device,
    )


## Compression with calibration

In [None]:
def calibrate(model, device, train_loader, max_iters = 1000,
              freeze_lnames = None):
    """Run no-grad forward passes in train mode over at most ``max_iters`` batches.

    NOTE(review): presumably meant to refresh BatchNorm running statistics after a
    decomposition step -- confirm. Before that, parameters whose owning module name
    appears in ``freeze_lnames`` get ``requires_grad = False``. This persists after the
    call, so it also affects any later fine-tuning of the returned model.

    Args:
        model: ``nn.Module`` to calibrate (modified in place).
        device: device the model and the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches; targets are ignored.
        max_iters: maximum number of batches to run.
        freeze_lnames: module names whose parameters are frozen; ``None`` freezes nothing.

    Returns:
        The same ``model`` object, switched to eval mode.
    """
    import itertools

    frozen = set(freeze_lnames or ())
    model.to(device).train()
    for pname, p in model.named_parameters():
        # "a.b.weight" -> "a.b": drop only the final attribute name. str.strip('.weight')
        # would strip a *character set* from both ends, e.g. "head.weight" -> "ad".
        if pname.rsplit('.', 1)[0] in frozen:
            p.requires_grad = False

    with torch.no_grad():
        for data, _ in notebook.tqdm(itertools.islice(train_loader, max_iters)):
            _ = model(data.to(device))

            del data
            torch.cuda.empty_cache()

    model.eval()
    return model

In [None]:
# Layer-by-layer Tucker-2 compression with a ``calibrate`` pass after every step.
# NOTE(review): this reads the ``rank`` buffers of the original conv layers, so it
# expects ``model_instance.model`` not to be compressed yet. Names whose module is not a
# groups == 1 layer (e.g. already decomposed) are skipped instead of raising on ``.groups``.
for idx, lname in enumerate(lnames_to_compress):

    ranks = {k: None for k in all_layer}

    for name, module in model_instance.model.named_modules():
        if name != lname:
            continue
        if getattr(module, "groups", None) == 1:
            # Fraction i of the ``rank`` buffer scales weight.shape[i]; keep each rank >= 2.
            layer_rank = [int(r * module.weight.shape[i]) for i, r in enumerate(module.rank)]
            ranks[lname] = [max(r, 2) for r in layer_rank]
        break

    if ranks[lname] is None:
        continue

    compressor = CompressorManual(
        model_instance.model,
        model_stats,
        ranks=ranks,
        ft_every=1,
        conv2d_nn_decomposition='tucker2',
        nglobal_compress_iters=1,
    )

    compressor.decompositions = {k: 'tucker2' for k in compressor.decompositions.keys()}

    while not compressor.done:
        print("\n Compress")
        compressor.compression_step()

        print("\n Calibrate")
        # Layers compressed in earlier iterations are passed as ``freeze_lnames``.
        # NOTE(review): if decomposition replaces a layer by a container of sub-layers, its
        # parameters are owned by "<lname>.<k>" modules and the freeze may never match.
        compressor.model = calibrate(compressor.compressed_model, device, train_dl, freeze_lnames = lnames_to_compress[:idx])

        compressor.compressed_model = compressor.compressed_model.to(device)

    model_instance.model = compressor.compressed_model

In [None]:
#calculate macs
# MAC count of the calibrated, compressed model for one 3 x IMG_SIZE x IMG_SIZE input.
macs = calc_macs(
    model_instance.model,
    (3,) + (data_config["IMG_SIZE"],) * 2,
)
print(f"macs: {macs}")

In [None]:
# Rebuild the train / validation / test loaders described by data.yml.
train_dl, val_dl, test_dl = create_dataloader(data_config)

In [None]:
#hyperparameter
# Optimizer, one-cycle LR schedule, loss and AMP scaler for fine-tuning the
# calibrated, compressed model.
train_path = os.path.join(data_config["DATA_PATH"], "train")
model_path = os.path.join(log_dir, "best.pt")

# OneCycleLR overrides the optimizer's lr, so lr=0.1 here is only the initial value.
optimizer = torch.optim.SGD(model_instance.model.parameters(), lr=0.1, momentum=0.9)

# Stepped once per batch for 200 epochs, matching the training loop below.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=data_config["INIT_LR"],
    steps_per_epoch=len(train_dl),
    epochs=200,
    pct_start=0.05,
)

# Class counts are only supplied for the TACO dataset.
criterion = CustomCriterion(
    samples_per_cls=(
        get_label_counts(train_path) if data_config["DATASET"] == "TACO" else None
    ),
    device=device,
)

# Mixed precision only when FP16 is enabled and we are not on the CPU.
if data_config['FP16'] and device != torch.device("cpu"):
    scaler = torch.cuda.amp.GradScaler()
else:
    scaler = None

In [None]:
# Fine-tune the calibrated, compressed model, checkpointing to ``model_path`` whenever
# validation macro F1 does not drop below the best seen so far. n_epoch must match the
# ``epochs`` given to OneCycleLR, which raises if stepped past its total step count.
n_epoch = 200

best_test_acc = -1.0
best_test_f1 = -1.0

num_classes = 9

label_list = list(range(num_classes))

for epoch in range(n_epoch):
    running_loss, correct, total = 0.0, 0, 0
    preds, gt = [], []
    pbar = notebook.tqdm(enumerate(train_dl), total=len(train_dl))
    model_instance.model.train()
    for batch, (data, labels) in pbar:

        data, labels = data.to(device), labels.to(device)

        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model_instance.model(data)
        else:
            outputs = model_instance.model(data)
        # flatten(1) rather than squeeze(): squeeze also removes the batch dimension of a
        # final batch holding a single sample.
        outputs = torch.flatten(outputs, 1)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()

        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        scheduler.step()

        _, pred = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()
        preds += pred.to("cpu").tolist()
        gt += labels.to("cpu").tolist()

        running_loss += loss.item()
        # No manual pbar.update(): tqdm already advances while being iterated.
        pbar.set_description(
            f"Train: [{epoch + 1:03d}] "
            f"Loss: {(running_loss / (batch + 1)):.3f}, "
            f"Acc: {(correct / total) * 100:.2f}% "
            f"F1(macro): {f1_score(y_true=gt, y_pred=preds, labels=label_list, average='macro', zero_division=0):.2f}"
        )
    pbar.close()

    _, test_f1, test_acc = test(
        model=model_instance.model, test_dataloader=val_dl
    )
    if best_test_f1 > test_f1:
        continue
    best_test_acc = test_acc
    best_test_f1 = test_f1
    print(f"Model saved. Current best test f1: {best_test_f1:.3f}")
    save_model(
        model=model_instance.model,
        path=model_path,
        data=data,
        device=device,
    )
