## FLOPS calculation

In [9]:
import os
import sys
import time
import numpy as np
import torch
import copy
import logging
from fvcore.nn import FlopCountAnalysis

# Suppress fvcore warnings to keep output clean
logging.getLogger("fvcore").setLevel(logging.ERROR)

# -----------------------------------------------------------------------------
# 1. SETUP PATHS
# -----------------------------------------------------------------------------
current_dir = os.getcwd()
parent_dir = current_dir
while os.path.basename(parent_dir) != "cellPIV" and parent_dir != "/":
    parent_dir = os.path.dirname(parent_dir)

if parent_dir not in sys.path:
    sys.path.append(parent_dir)
    print(f"Added to sys.path: {parent_dir}")

# -----------------------------------------------------------------------------
# 2. IMPORT PROJECT MODULES
# -----------------------------------------------------------------------------
try:
    from config import Config_03_train as conf_train
    from _03_train._b_LSTMFCN import TimeSeriesClassifier
    from _99_ConvTranModel.model import model_factory
    import _utils_._utils as utils
    
    # ROCKET Import
    try:
        from tsai.models.Rocket import Rocket
    except ImportError:
        try:
            from sktime.transformations.panel.rocket import Rocket
        except ImportError:
            Rocket = None
            
    print("Successfully imported project modules.")
except ImportError as e:
    print(f"Error importing modules: {e}")

# -----------------------------------------------------------------------------
# 3. METRICS HELPERS
# -----------------------------------------------------------------------------
def get_sequence_length(hours):
    # Based on config.py: framePerHour = 4
    return int(hours * 4)

def count_parameters(model):
    """Returns count of trainable parameters in Millions (M)"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6

def measure_inference_time(model, dummy_input, device="cpu", repetitions=100):
    """Measures average inference time in milliseconds (ms)"""
    model.eval()
    model.to(device)
    dummy_input = dummy_input.to(device)

    # Warmup
    with torch.no_grad():
        for _ in range(10): _ = model(dummy_input)
    
    # Measure
    if device == "cuda": torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        for _ in range(repetitions): _ = model(dummy_input)
    if device == "cuda": torch.cuda.synchronize()
    end = time.time()
    
    return ((end - start) / repetitions) * 1000

def measure_rocket_time(input_shape, repetitions=50):
    """Simulates ROCKET inference time (Transform + Linear)"""
    if Rocket is None: return "N/A"
    try:
        # 1. Prepare Input as NUMPY (ROCKET usually requires CPU Numpy)
        # input_shape is (1, 1, T)
        X_np = np.random.randn(*input_shape).astype(np.float32)
        
        # 2. Init ROCKET
        # tsai Rocket usually takes (c_in, seq_len)
        try:
            rocket = Rocket(c_in=input_shape[1], seq_len=input_shape[2])
        except:
            rocket = Rocket() # sktime fallback
            
        # 3. Fit (required once)
        # Pass numpy array to fit
        if hasattr(rocket, 'fit'): 
            rocket.fit(X_np)
            
        # 4. Measure Transform Time
        start = time.time()
        for _ in range(repetitions):
            # Transform
            if hasattr(rocket, 'transform'):
                feats = rocket.transform(X_np)
            else:
                feats = rocket(X_np) # callable
            
            # Linear Classifier Simulation (Dot product)
            # feats might be a tensor or numpy array depending on library
            if isinstance(feats, torch.Tensor):
                feats = feats.detach().cpu().numpy()
            
            # Simulate linear layer: features @ weights
            # feats shape: (Batch, N_Kernels)
            _ = feats @ np.random.randn(feats.shape[1], 1)
            
        end = time.time()
        return ((end - start) / repetitions) * 1000
        
    except Exception as e:
        return f"Err: {str(e)[:20]}"

# -----------------------------------------------------------------------------
# 4. MAIN LOOP
# -----------------------------------------------------------------------------
day_map = {1: 24, 3: 72, 5: 120}
device = "cpu" # CPU for fair comparison

results_log = []
output_file = os.path.join(parent_dir, "paper_figures", "results_computational_load.txt")
os.makedirs(os.path.dirname(output_file), exist_ok=True)

header = f"{'Model':<20} | {'Params (M)':<12} | {'FLOPs (G)':<12} | {'Time (ms)':<12}"
print(f"\n{header}")
print("-" * len(header))
results_log.append(header)

base_models_path = conf_train.output_model_base_dir

for day in [1, 3, 5]:
    hours = day_map[day]
    T = get_sequence_length(hours)
    # Input shape: (Batch=1, Channels=1, Time)
    input_shape = (1, 1, T)
    dummy_input = torch.randn(*input_shape)

    # --- 1. LSTMFCN ---
    try:
        model = TimeSeriesClassifier(input_channels=1, num_classes=2)
        params = count_parameters(model)
        
        # Calculate FLOPs (suppressing warnings)
        flops = FlopCountAnalysis(model, dummy_input)
        flops.unsupported_ops_warnings(False) # Clean output
        flops_g = flops.total() / 1e9
        
        time_ms = measure_inference_time(model, dummy_input, device=device)
        
        msg = f"LSTMFCN_{hours}h".ljust(20) + f" | {params:.4f}".ljust(15) + f" | {flops_g:.6f}".ljust(15) + f" | {time_ms:.4f}"
        print(msg)
        results_log.append(msg)
    except Exception as e:
        err = f"LSTMFCN_{hours}h: {str(e)}"
        print(err)
        results_log.append(err)

    # --- 2. ConvTran ---
    try:
        conf_local = copy.deepcopy(conf_train)
        conf_local.Data_shape = (1, T)
        if not hasattr(conf_local, 'num_labels'): conf_local.num_labels = 2
        
        model = model_factory(conf_local)
        params = count_parameters(model)
        
        flops = FlopCountAnalysis(model, dummy_input)
        flops.unsupported_ops_warnings(False)
        flops_g = flops.total() / 1e9
        
        time_ms = measure_inference_time(model, dummy_input, device=device)
        
        msg = f"ConvTran_{hours}h".ljust(20) + f" | {params:.4f}".ljust(15) + f" | {flops_g:.6f}".ljust(15) + f" | {time_ms:.4f}"
        print(msg)
        results_log.append(msg)
    except Exception as e:
        err = f"ConvTran_{hours}h: {str(e)}"
        print(err)
        results_log.append(err)

    # --- 3. ROCKET ---
    try:
        time_ms = measure_rocket_time(input_shape)
        if isinstance(time_ms, float):
            t_str = f"{time_ms:.4f}"
        else:
            t_str = time_ms
            
        msg = f"ROCKET_{hours}h".ljust(20) + f" | {'N/A':<12} | {'N/A':<12} | {t_str}"
        print(msg)
        results_log.append(msg)
    except Exception as e:
        err = f"ROCKET_{hours}h: {str(e)}"
        print(err)
        results_log.append(err)
        
    print("-" * 60)
    results_log.append("-" * 60)

# Save
with open(output_file, "w") as f:
    f.write("\n".join(results_log))
print(f"Results saved to {output_file}")

Successfully imported project modules.

Model                | Params (M)   | FLOPs (G)    | Time (ms)   
-----------------------------------------------------------------
LSTMFCN_24h          | 0.9850       | 0.054563     | 10.4716
ConvTran_24h         | 0.1892       | 0.013861     | 1.0281
ROCKET_24h           | N/A          | N/A          | 4.3663
------------------------------------------------------------
LSTMFCN_72h          | 0.9850       | 0.177345     | 11.4418
ConvTran_72h         | 0.1923       | 0.055739     | 2.6235
ROCKET_72h           | N/A          | N/A          | 9.7241
------------------------------------------------------------
LSTMFCN_120h         | 0.9850       | 0.319001     | 17.7284
ConvTran_120h        | 0.1953       | 0.116490     | 5.4947
ROCKET_120h          | N/A          | N/A          | 15.0445
------------------------------------------------------------
Results saved to /home/phd2/Scrivania/CorsoRepo/cellPIV/paper_figures/results_computational_load.txt
