In [8]:
import argparse
import logging
import math
import os
import random
import time
from copy import deepcopy
from pathlib import Path
from threading import Thread

import numpy as np
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import test  # import test.py to get mAP after each epoch
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
    check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss, ComputeLossOTA
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
from models.common import CBAM
import json
import pandas as pd

logger = logging.getLogger(__name__)


def detect(cfg, nc, hyp, device, img_size=(640, 640),
           info_path='./temp_model_info.json', flop_path='./model_flop.json'):
    """Print a per-layer parameter/FLOP table for a model config, plus totals.

    The model is built from ``cfg`` only to read its maximum stride. Per-layer
    metadata from ``info_path`` is joined with per-layer FLOP counts from
    ``flop_path``, printed as a table, and followed by a scaled FLOP total and
    the total parameter count.

    Args:
        cfg: Path to the model YAML config passed to ``Model``.
        nc: Number of classes passed to ``Model``.
        hyp: Path to a hyperparameter YAML file; only its ``anchors`` entry
            (if present) is used.
        device: Device the model is moved to.
        img_size: (height, width) used to scale the summed per-layer FLOPs
            by ``(img_size[0] / stride) * (img_size[1] / stride)``.
        info_path: JSON file mapping layer key -> dict of layer attributes;
            the attributes must include a ``"from"`` column, which is dropped.
        flop_path: JSON file whose ``"model"`` entry is a list: element 1 is
            printed as the total params, element 2 maps layer key -> a
            sequence whose first item is that layer's FLOP count.

    Raises:
        KeyError: if a layer key in ``info_path`` has no entry in ``flop_path``.
    """
    with open(hyp) as f:
        hyp_dict = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps
    model = Model(cfg, ch=3, nc=nc, anchors=hyp_dict.get('anchors')).to(device)  # create
    # Use the model's largest stride (at least 32) as the FLOP scaling unit.
    stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32

    with open(info_path) as json_file:
        model_info = json.load(json_file)
    with open(flop_path) as flop_file:
        flop_entry = json.load(flop_file)['model']

    params = flop_entry[1]
    layer_flops = {key: value[0] for key, value in flop_entry[2].items()}

    for key, layer in model_info.items():
        layer['flop'] = int(layer_flops[key])

    df = pd.DataFrame.from_dict(model_info, orient='index')
    df.drop(columns=['from'], inplace=True)
    df.reset_index(inplace=True)

    # Scope the display options to this print instead of mutating pandas'
    # process-wide settings.
    with pd.option_context('display.max_rows', None, 'display.max_columns', 5):
        print(df)

    flops = df['flop'].sum()
    # NOTE(review): the "GFLOPS" label assumes the per-layer values are already
    # in GFLOPs; the recorded output shows raw integer counts (~1e6 per layer),
    # so this total may need a / 1e9 (and possibly a MAC->FLOP x2) — confirm.
    print(f"Total flop: {flops * img_size[0] / stride * img_size[1] / stride} GFLOPS")
    print(f"Total params: {params}")
    


    


if __name__ == '__main__':
    # Print the layer table and totals for the v4-swintf-dwconv config with
    # 10 classes, building the model on the CPU.
    cpu = torch.device("cpu")
    detect(
        cfg="./cfg/training/v4-swintf-dwconv.yaml",
        nc=10,
        hyp="./data/hyp.scratch.p5.yaml",
        device=cpu,
    )


   index  n  params                                module     flop
0      0  1     928                    models.common.Conv  1015808
1      1  1   18560                    models.common.Conv  4784128
2      2  1     704             n DWConv at 0x7f96f79073a   212992
3      3  1    2112                    models.common.Conv   557056
4      4  1    2112                    models.common.Conv   557056
5      5  1    9280                    models.common.Conv  2392064
6      6  1    9280                    models.common.Conv  2392064
7      7  1       0                  models.common.Concat        0
8      8  1    8320                    models.common.Conv  2162688
9      9  1       0                      models.common.MP        0
10    10  1     128             n DWConv at 0x7f96f79073a    12288
11    11  1     640             n DWConv at 0x7f96f79073a    45056
12    12  1       0                  models.common.Concat        0
13    13  1     384             n DWConv at 0x7f96f79073a    4