In [None]:
# Install the packages listed in requirements.txt offline (--no-index), using only the
# wheels bundled in the byu-dataset input. `!python -m pip` targets the `python` found on
# PATH, which is the same interpreter later used to launch ddp.py via `!... python ddp.py`.
!python -m pip install --no-index --find-links=/kaggle/input/byu-dataset/byu_dataset/packages -r /kaggle/input/byu-dataset/byu_dataset/packages/requirements.txt
print("done")

# Quantile of the per-tomogram peak scores at or below which a tomogram's predicted
# coordinates are replaced by -1 in the submission cell.
# NOTE(review): "3xTTA" presumably records the TTA setting this value was tuned with — confirm.
QUANTILE_THRESHOLD= 0.545 # 3xTTA

In [None]:
import os
import sys
# Make the project's `src` package (shipped in the byu-dataset input) importable.
sys.path.append('/kaggle/input/byu-dataset/byu_dataset/src')

# Scratch dir for the per-tomogram predictions (.pt) and metadata (.json) that ddp.py writes.
# makedirs(exist_ok=True) is idempotent, has no check-then-create race, and also creates
# any missing parent directories.
for path in ["/tmp/working/"]:
    os.makedirs(path, exist_ok=True)
print("done")

In [None]:
import os
import shutil

def clean_working(directory_path: str = "/kaggle/working/", keep: tuple = ("submission.csv",)):
    """
    Clean kaggle output directory.

    Removes every entry directly inside `directory_path` except those whose names are
    listed in `keep`. Files and symlinks (including symlinks to directories and broken
    links) are unlinked; real directories are removed recursively. If the directory does
    not exist, a message is printed and nothing is removed.

    Args:
        directory_path: Directory to empty. Defaults to Kaggle's output directory.
        keep: Entry names to leave in place. Defaults to ``("submission.csv",)``.
    """
    if not os.path.exists(directory_path):
        print(f"'{directory_path}' does not exist.")
        return

    for item in os.listdir(directory_path):
        if item in keep:
            continue
        item_path = os.path.join(directory_path, item)
        # Test for a real directory explicitly: os.path.isfile follows symlinks, and
        # shutil.rmtree refuses to operate on a symlink, so a link to a directory (or a
        # broken link / special file) must be unlinked with os.remove instead.
        if os.path.isdir(item_path) and not os.path.islink(item_path):
            shutil.rmtree(item_path)
        else:
            os.remove(item_path)
    print(f"All items in '{directory_path}' have been removed.")
        
# Start from a clean slate: empty Kaggle's output dir (submission.csv is preserved)
# and the scratch dir used for intermediate predictions.
for _target_dir in ("/kaggle/working/", "/tmp/working/"):
    clean_working(_target_dir)
print("done")

## Inference

In [None]:
%%writefile ddp.py

from types import SimpleNamespace
import os
import sys
import json

# Outside Kaggle, pin the visible GPUs; on Kaggle, make the bundled `src` package importable.
# An empty KAGGLE_URL_BASE counts as unset: a launcher that writes `KAGGLE_URL_BASE=`
# for "off" still defines the variable in this process's environment.
if not os.environ.get('KAGGLE_URL_BASE'):
    os.environ['CUDA_VISIBLE_DEVICES']= "2,3"
else:
    sys.path.append('/kaggle/input/byu-dataset/byu_dataset/src')

import glob
import pickle
from copy import deepcopy
from tqdm import tqdm

import pandas as pd 
import numpy as np

import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.amp import autocast, GradScaler

from monai.inferers import sliding_window_inference

from src.modules.utils import batch_to_device
from src.models.utils import get_model
from src.data._3d import CustomDataset
from src.utils.torch import center_of_mass_3d


import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler


def _edge_downweight_map(shape, pct, low, device):
    """
    Build the sliding-window ROI weight map: 1.0 in the interior, `low` near every face.

    Each axis of size `size` gets a margin of ``int(size * pct)`` voxels on both sides.
    Margins are applied with explicit ``[0, m)`` and ``[size - m, size)`` slices, so a
    margin that rounds down to 0 leaves that axis untouched. A negative slice such as
    ``[-0:]`` would instead select the whole axis and flatten the map to `low`.

    Args:
        shape: Spatial size of one sliding-window patch.
        pct: Fraction of each axis treated as edge on each side.
        low: Weight assigned to edge voxels.
        device: Torch device for the returned tensor.

    Returns:
        Tensor of `shape` (torch default float dtype) on `device`.
    """
    weights = torch.ones(tuple(shape), device=device)
    for axis, size in enumerate(shape):
        margin = int(size * pct)
        if margin == 0:
            continue
        index = [slice(None)] * len(shape)
        index[axis] = slice(0, margin)
        weights[tuple(index)] = low
        index[axis] = slice(size - margin, size)
        weights[tuple(index)] = low
    return weights


def run_inference(rank, world_size):
    """
    Run sliding-window inference for one DDP rank and write per-tomogram outputs.

    For every tomogram that the DistributedSampler assigns to this rank, the sigmoid
    heatmaps of all loaded models are summed. The result is saved as float16 on CPU to
    ``{working_dir}{tomo_id}_{script_idx}.pt``, together with a JSON sidecar holding
    the tomo_id, the script index, and the z/y/x shapes reported by the batch. A
    tomogram whose processing raises is logged and skipped.

    Args:
        rank: Process rank; also used as the CUDA device index.
        world_size: Number of DDP processes.

    Returns:
        Tuple ``(c, preds_final)``: the run config namespace, and a list that is
        currently never filled (kept for interface compatibility).

    Raises:
        ValueError: If not running on Kaggle.
        FileNotFoundError: If no model checkpoints match ``c.model_dir``.
    """
    import traceback

    # ========== Config ==========
    if is_kaggle():
        c= SimpleNamespace()
        c.working_dir= "/tmp/working/"
        c.model_dir= "/kaggle/input/byu-dataset/byu_dataset/byu_models/*999*.pt"
        c.img_size= (128, 672, 672)
        c.in_dir= "/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/test/"
        c.script_idx= int(os.environ.get("SCRIPT_IDX"))
        c.local_rank= rank
    else:
        raise ValueError("Are we in the Matrix?")

    # ========== Dataframe ===========
    # One row per entry of the test folder; fold is set to -1 for every row.
    df= pd.DataFrame([{"tomo_id": tomo_id, "fold": -1} for tomo_id in os.listdir(c.in_dir)])

    # ========== Models ===========
    mpaths= sorted(glob.glob(c.model_dir))
    total_models= len(mpaths)
    if total_models == 0:
        # Fail early; otherwise the error surfaces later as an unbound `model_cfg`.
        raise FileNotFoundError("No model checkpoints match {}".format(c.model_dir))

    # Split the checkpoints between the two launcher runs (SCRIPT_IDX 0 / 1).
    # NOTE(review): both halves include mpaths[mid], so when both runs' outputs are
    # ensembled that checkpoint is counted twice — confirm this is intended.
    mid= total_models//2
    if c.script_idx == 0:
        mpaths= mpaths[:mid + 1]
    else:
        mpaths= mpaths[mid:]

    if c.local_rank == 0:
        print("="*25)
        print("Loading {}/{} models..".format(len(mpaths), total_models))
        for m in mpaths:
            print(m)

    models= []
    for fpath in mpaths:
        # Each checkpoint has a pickled config next to it.
        # NOTE: pickle.load can execute arbitrary code; only load trusted model files.
        with open(fpath.replace(".pt", ".pkl"), "rb") as f:
            model_cfg= pickle.load(f)

        if not hasattr(model_cfg, "infer_cfg"):
            model_cfg.infer_cfg = SimpleNamespace()

        # Override the stored settings for this inference run.
        model_cfg.local_rank= rank
        model_cfg.weights_path= fpath
        model_cfg.data_dir= c.in_dir
        model_cfg.infer_cfg.sw_batch_size= 2
        model_cfg.tta= True
        model_cfg.img_size= c.img_size
        model_cfg.infer_cfg.overlap= (0.875, 0.25, 0.25)

        # Load the model, move it to this rank's GPU and wrap it for DDP.
        m, _= get_model(model_cfg, inference_mode=True)
        m= DistributedDataParallel(
            m.to(rank),
            device_ids= [rank],
            output_device= rank,
            )
        m= m.eval()

        models.append({
            "model": m,
            "cfg": model_cfg,
        })

    if c.local_rank == 0:
        print("="*25)
    cfg= deepcopy(model_cfg)

    # ========== Datasets / Dataloader ===========
    test_ds= CustomDataset(
        cfg=model_cfg,
        df=df,
        mode="test",
        )

    # With drop_last=False (the default), DistributedSampler pads the index list so every
    # rank gets the same number of samples, so a tomogram may be processed on two ranks.
    sampler= DistributedSampler(
        test_ds,
        num_replicas= world_size,
        rank= rank,
    )

    test_dl= torch.utils.data.DataLoader(
        dataset=test_ds,
        batch_size=1,
        num_workers=2,
        sampler=sampler,
        pin_memory=False,
        shuffle=False,
        drop_last=False,
    )

    # ========== Inference Loop ===========
    preds_final= []
    with torch.no_grad():
        # ROI weight map: downweight the outer 30% of each patch axis to 1e-3.
        roi_weight_map= _edge_downweight_map(c.img_size, pct=0.30, low=1e-3, device=rank)

        for batch in tqdm(test_dl):
            # Reset per batch so an early failure never reports a stale (or unbound) id.
            tomo_id= None
            with autocast(cfg.device.type):
                try:
                    tomo_id= batch.pop("tomo_id")[0]
                    batch = batch_to_device(batch, device=rank)
                    batch["input"]= batch["input"].float()

                    # Sliding window: sum the sigmoid heatmaps of all models.
                    preds= None
                    for row in models:
                        _preds = sliding_window_inference(
                            inputs= batch["input"],
                            roi_size= row["cfg"].roi_size,
                            predictor= row["model"],
                            roi_weight_map= roi_weight_map,
                            **vars(row["cfg"].infer_cfg)
                        )[0, 0, ...]
                        _preds= torch.sigmoid(_preds)

                        if preds is None:
                            preds = _preds
                        else:
                            preds += _preds

                    # Save Tensor (float16 on CPU) plus a JSON sidecar with the shapes.
                    preds= preds.half().cpu()
                    fpath= "{}{}_{}.pt".format(c.working_dir, tomo_id, c.script_idx)
                    torch.save(preds, fpath)
                    outpath= fpath.replace(".pt", ".json")

                    with open(outpath, "w+") as f:
                        json.dump({
                            "tomo_id": tomo_id,
                            "script_id": c.script_idx,
                            "z_shape": batch["z_shape"][0].item(),
                            "y_shape": batch["y_shape"][0].item(),
                            "x_shape": batch["x_shape"][0].item(),
                        }, f)
                except Exception:
                    # Best effort: log the failure and continue with the next tomogram.
                    # (Catching Exception, not everything, lets Ctrl-C / SystemExit through.)
                    print(traceback.format_exc())
                    print("failed tomo_id:", tomo_id)

    return c, preds_final


def is_kaggle():
    """
    Return True when running inside a Kaggle kernel.

    Detection relies on the ``KAGGLE_URL_BASE`` environment variable. An empty value is
    treated as unset, because a launcher that writes ``KAGGLE_URL_BASE=`` to mean "off"
    still defines the variable in the child environment.
    """
    return bool(os.getenv("KAGGLE_URL_BASE"))

def run_DDP(run_fn, world_size):
    """
    Launch `world_size` processes with torch.multiprocessing.spawn.

    Each process calls ``run_fn(rank, world_size)``; the call blocks until all exit.
    """
    spawn_args = (world_size,)
    mp.spawn(run_fn, nprocs=world_size, args=spawn_args, join=True)
    
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '9931'
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup(rank):
    """Destroy the default process group; only rank 0 announces completion."""
    dist.destroy_process_group()
    is_primary = rank == 0
    if is_primary:
        print("DDP complete.")

def _seed_worker(seed=0):
    """
    Apply the reproducibility settings inside the current (worker) process.

    torch.multiprocessing.spawn starts each worker as a fresh interpreter. That
    interpreter does not execute the parent's ``if __name__ == "__main__":`` block,
    so seeds and cuDNN flags set there never reach the workers. torch.manual_seed
    also seeds the CUDA generators.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def run_all(rank, world_size):
    """
    Per-process DDP entry point: seed, join the process group, run inference, tear down.

    Args:
        rank: Process index supplied by torch.multiprocessing.spawn (also the GPU index).
        world_size: Total number of processes.
    """
    print(f"Running DDP code on rank {rank}.")
    _seed_worker(0)
    setup(rank, world_size)

    # Outputs are written to disk by run_inference; its return value is not needed here.
    run_inference(rank, world_size)

    cleanup(rank)

if __name__ == "__main__":
    # Seed the launcher process for reproducibility. Spawned workers do not run this
    # block, so any per-worker seeding has to happen inside the worker function.
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    n_gpus = torch.cuda.device_count()
    print(f"total GPUs: {n_gpus}")
    if n_gpus == 0:
        # With zero GPUs, nprocs would be 0, so no worker would start and no predictions
        # would be written. Fail loudly here instead of downstream in the notebook.
        raise RuntimeError("No CUDA devices visible; NCCL-based DDP inference needs at least one GPU.")
    world_size = n_gpus
    run_DDP(run_all, world_size)

In [None]:
import time
start_time= time.time()

kicr = "true" if os.getenv('KAGGLE_IS_COMPETITION_RERUN') else ""
kub = "true" if 'KAGGLE_URL_BASE' in os.environ else ""

!KAGGLE_IS_COMPETITION_RERUN={kicr} KAGGLE_URL_BASE={kub} SCRIPT_IDX=0 python ddp.py

In [None]:
# !KAGGLE_IS_COMPETITION_RERUN={kicr} KAGGLE_URL_BASE={kub} SCRIPT_IDX=1 python ddp.py

In [None]:
import json
import glob
import os
import pandas as pd 
import torch

# =========== Load infer metadata ==========
# Each JSON sidecar written by ddp.py describes one (tomo_id, script) prediction file.
# The matching tensor has the same base name with a .pt extension.
d= []
for fpath in sorted(glob.glob("/tmp/working/*.json")):
    with open(fpath, "r") as f:
        metadata= json.load(f)
    # splitext only swaps the extension, never a ".json" elsewhere in the path.
    d.append(metadata | {"fpath": os.path.splitext(fpath)[0] + ".pt"})

if not d:
    raise FileNotFoundError("No prediction metadata found in /tmp/working/ - did ddp.py run successfully?")

df = pd.DataFrame(d)

# Sanity check
# DDP might predict 2x for same tomo (DistributedSampler pads ranks to equal length), so
# collapse to one row per tomo_id. The file list is de-duplicated and sorted so the
# ensemble order below is deterministic (set iteration order of strings varies between runs).
df = df.groupby('tomo_id', as_index=False).agg({
    'tomo_id': 'first',
    'script_id': 'first',
    'z_shape': 'first',
    'y_shape': 'first',
    'x_shape': 'first',
    'fpath': lambda x: sorted(set(x)),
})
display(df)

In [None]:
from tqdm import tqdm
import numpy as np


# ========== Ensemble volumes ===========
# For each tomogram: combine its saved heatmaps, take the highest-scoring voxel, and map
# that voxel's centre back to the original tomogram size recorded in the metadata.
sub_rows= []
for _, row in tqdm(df.iterrows(), total=len(df)):
    row= row.to_dict()

    # Ensemble: upcast the saved float16 heatmaps to float32 and average them.
    # Averaging keeps peak scores on the same scale for tomograms that have a different
    # number of prediction files, which the quantile cutoff in the submission step relies on.
    arr= torch.stack(
        [torch.load(f, weights_only=False).float() for f in row["fpath"]],
        dim=0,
    )
    arr= arr.mean(dim=0) # Mean ensemble
    # arr= arr.median(dim=0).values # Median ensemble
    # arr= (arr.clamp(min=0.01).log().mean(dim=0)).exp() # Geometric mean ensemble

    # Argmax -> voxel-centre coordinates normalised to (0, 1) along each axis
    coords = torch.unravel_index(torch.argmax(arr), arr.shape)
    coords = tuple(
        (coords[axis].item() + 0.5) / arr.shape[axis] for axis in range(3)
    )

    # Add
    sub_rows.append({
        "tomo_id": row["tomo_id"],
        "z": coords[0] * row["z_shape"],
        "y": coords[1] * row["y_shape"],
        "x": coords[2] * row["x_shape"],
        "max": arr.max().item(),
    })

sub= pd.DataFrame(sub_rows)
display(sub)

## Format submission

In [None]:
# Apply Threshold: tomograms whose peak score is at or below the QUANTILE_THRESHOLD
# quantile of all peak scores get -1 coordinates (presumably the competition's
# "no motor" encoding — confirm against the competition rules).
cutoff= sub['max'].quantile(QUANTILE_THRESHOLD)
sub.loc[sub["max"] <= cutoff, ["z", "y", "x"]]= -1.0
print("="*25)
print("threshold:", QUANTILE_THRESHOLD)
print("cutoff:", cutoff)
print("="*25)

# Format sub: rename to the submission column names and keep only those columns.
col_map= {
    "z": "Motor axis 0",
    "y": "Motor axis 1",
    "x": "Motor axis 2",
}
sub= sub.rename(columns=col_map)[["tomo_id", "Motor axis 0", "Motor axis 1", "Motor axis 2"]]
sub.to_csv("submission.csv", index=False)
display(sub)

# Remove the scratch predictions (.pt/.json files ddp.py wrote to /tmp/working/), then
# everything in /kaggle/working/ except submission.csv.
clean_working("/tmp/working/")
clean_working("/kaggle/working/")
