In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image
from scipy.ndimage import zoom
from torch.amp import autocast, GradScaler
from sklearn.metrics import fbeta_score
from tqdm import tqdm

In [2]:
class MotorVolumeDataset(Dataset):
    """Tomogram volumes resampled to a fixed shape, paired with motor labels.

    Each sub-directory of ``root_dir`` is treated as one tomogram whose
    (name-sorted) files are its depth slices. Only metadata is gathered in
    ``__init__``; the slices themselves are read and resized in
    ``__getitem__``.

    Args:
        root_dir: Directory containing one sub-directory per tomogram.
        label_csv: CSV with a ``tomo_id`` column and the three
            ``Motor axis 0/1/2`` coordinate columns (``-1`` in
            ``Motor axis 0`` marks "no motor").
        is_train: When truthy, hold-out tomograms are skipped; otherwise
            only hold-out tomograms are kept.
        tomo_filter: Optional container of tomogram ids; when given, ids not
            in it are skipped.
        target_shape: ``(depth, height, width)`` every volume is resized to.
        holdout_tomos: Optional iterable of hold-out tomogram ids; defaults
            to ``DEFAULT_HOLDOUT_TOMOS``.

    Raises:
        ValueError: If ``target_shape`` does not have exactly three entries.
    """

    # Tomograms skipped when is_train is truthy and the only ones kept otherwise.
    DEFAULT_HOLDOUT_TOMOS = frozenset({'tomo_003acc', 'tomo_00e047', 'tomo_01a877'})
    DEFAULT_TARGET_SHAPE = (128, 704, 704)
    COORD_COLUMNS = ['Motor axis 0', 'Motor axis 1', 'Motor axis 2']

    def __init__(self, root_dir, label_csv, is_train=True, tomo_filter=None,
                 target_shape=DEFAULT_TARGET_SHAPE, holdout_tomos=None):
        self.root_dir = root_dir
        self.label_df = pd.read_csv(label_csv)
        self.is_train = is_train
        self.transform = None  # never applied in this class; attribute kept for compatibility
        self.target_shape = tuple(target_shape)
        if len(self.target_shape) != 3:
            raise ValueError("target_shape must be (depth, height, width)")

        holdout = self.DEFAULT_HOLDOUT_TOMOS if holdout_tomos is None else set(holdout_tomos)

        self.data_info = []

        for tomo_id in sorted(os.listdir(root_dir)):
            # Train mode drops the hold-out ids; non-train mode keeps only them.
            if bool(is_train) == (tomo_id in holdout):
                continue
            if tomo_filter is not None and tomo_id not in tomo_filter:
                continue

            tomo_path = os.path.join(root_dir, tomo_id)
            if not os.path.isdir(tomo_path):  # ignore stray files next to the tomogram folders
                continue

            slices_names = sorted(os.listdir(tomo_path))
            if not slices_names:  # skip empty folders
                continue

            # Only the first slice is opened here, to learn the in-plane size.
            # PIL reports ``size`` as (width, height), not (height, width).
            with Image.open(os.path.join(tomo_path, slices_names[0])) as first_slice:
                original_w, original_h = first_slice.size
            original_shape = (len(slices_names), original_h, original_w)

            zoom_factors = [
                target / original
                for target, original in zip(self.target_shape, original_shape)
            ]

            # Ground-truth rows for this tomogram; -1 in axis 0 means "no motor".
            motors = self.label_df[self.label_df['tomo_id'] == tomo_id]
            valid = motors[motors['Motor axis 0'] != -1]
            coords = valid[self.COORD_COLUMNS].values

            if len(coords) > 0:
                # Express the labels in the coordinate frame of the resized volume.
                label = torch.tensor(coords * np.array(zoom_factors), dtype=torch.float32)
                has_motor = torch.tensor([1.0], dtype=torch.float32)
            else:
                label = torch.tensor([[-1, -1, -1]], dtype=torch.float32)
                has_motor = torch.tensor([0.0], dtype=torch.float32)

            # Everything __getitem__ needs to load the volume lazily.
            self.data_info.append({
                'tomo_id': tomo_id,
                'tomo_path': tomo_path,
                'slices_names': slices_names,
                'original_shape': original_shape,
                'zoom_factors': zoom_factors,
                'gt_coords': label,
                'has_motor': has_motor
            })

    def __len__(self):
        """Return the number of tomograms that passed the filters."""
        return len(self.data_info)

    def __getitem__(self, idx):
        """Load, resize and scale one tomogram.

        Returns:
            dict with keys ``volume`` (float tensor ``[1, D, H, W]`` divided
            by 255), ``gt_coords`` (``[N, 3]`` in resized coordinates, or
            ``[[-1, -1, -1]]``), ``has_motor`` (``[1]``), ``tomo_id`` and
            ``inverse_scale`` (list of three floats mapping resized
            coordinates back to the original volume).
        """
        info = self.data_info[idx]

        # 1. Load the slices, closing each file handle as soon as it is decoded.
        slices = []
        for slice_name in info['slices_names']:
            with Image.open(os.path.join(info['tomo_path'], slice_name)) as img:
                slices.append(np.array(img))
        volume = np.stack(slices, axis=0)  # [D, H, W]

        # 2. Resize with linear interpolation, then scale to float.
        volume_resized = zoom(volume, info['zoom_factors'], order=1)
        volume_tensor = torch.from_numpy(
            volume_resized.astype(np.float32, copy=False)
        ).unsqueeze(0) / 255.0  # [1, D, H, W]

        # Multiplying a resized coordinate by these maps it back to the original grid.
        inverse_scale = [1.0 / factor for factor in info['zoom_factors']]

        return {
            'volume': volume_tensor,
            'gt_coords': info['gt_coords'],
            'has_motor': info['has_motor'],
            'tomo_id': info['tomo_id'],
            'inverse_scale': inverse_scale
        }


In [3]:
# ===================== Model =====================
class MotorNet3D(nn.Module):
    """Small 3D CNN with a coordinate head and a confidence head.

    The backbone is three Conv3d+ReLU stages (1->8->16->32 channels); the
    first two are followed by 2x max pooling and the last by global average
    pooling, giving a 32-dim feature vector per sample.
    """

    def __init__(self):
        super().__init__()
        channel_plan = [(1, 8), (8, 16), (16, 32)]
        last_stage = len(channel_plan) - 1

        layers = []
        for stage, (c_in, c_out) in enumerate(channel_plan):
            layers.append(nn.Conv3d(c_in, c_out, 3, padding=1))
            layers.append(nn.ReLU())
            # Downsample between stages; collapse to 1x1x1 after the final one.
            layers.append(nn.AdaptiveAvgPool3d(1) if stage == last_stage else nn.MaxPool3d(2))
        self.backbone = nn.Sequential(*layers)

        self.fc_coord = nn.Linear(32, 3)  # (z, y, x)
        self.fc_conf = nn.Linear(32, 1)

    def forward(self, x):
        """Map ``x`` of shape [B, 1, D, H, W] to (coords [B, 3], conf_logits [B, 1])."""
        features = self.backbone(x)
        features = features.view(features.size(0), -1)  # [B, 32]
        return self.fc_coord(features), self.fc_conf(features)

In [4]:
# ===================== Loss =====================
def motor_loss(pred_coord, pred_conf_logits, gt_coords, has_motor, alpha=1.0, beta=2.0):
    """Weighted sum of a nearest-motor coordinate loss and a presence BCE loss.

    Args:
        pred_coord: Predicted coordinates, one 3-vector per sample.
        pred_conf_logits: Raw (pre-sigmoid) presence logits.
        gt_coords: Per-sample ground-truth coordinates, indexable by sample
            and yielding an ``[N, 3]`` tensor.
        has_motor: Per-sample presence flags (1 means a motor is present).
        alpha: Weight of the coordinate term.
        beta: Weight of the presence term.

    Returns:
        Scalar tensor ``alpha * coord_loss + beta * conf_loss``.
    """
    dev = pred_coord.device
    coord_total = torch.tensor(0.0, device=dev)

    for sample_idx in range(pred_coord.size(0)):
        if has_motor[sample_idx] != 1:
            continue  # no motor: this sample contributes nothing to the coordinate term
        targets = gt_coords[sample_idx].to(dev)   # [N, 3]
        guess = pred_coord[sample_idx]            # [3]
        # Regress towards whichever ground-truth motor is nearest to the prediction.
        nearest = (targets - guess.unsqueeze(0)).norm(dim=1).argmin()
        coord_total += F.mse_loss(guess, targets[nearest])

    # Average over samples that have a motor (at least 1 to avoid dividing by zero).
    n_positive = has_motor.sum().clamp(min=1)
    coord_term = coord_total / n_positive

    conf_term = F.binary_cross_entropy_with_logits(
        pred_conf_logits.squeeze(), has_motor.squeeze()
    )
    return alpha * coord_term + beta * conf_term

In [5]:
# ===================== Run =====================
data_root = '/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/train'
label_path = '/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/train_labels.csv'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_cuda = device.type == 'cuda'

# Tunable run settings, gathered in one place.
NUM_EPOCHS = 3
LEARNING_RATE = 1e-3
SEED = 42

# Seed before the model is built and the loader shuffles, so reruns are repeatable.
torch.manual_seed(SEED)

train_dataset = MotorVolumeDataset(
    root_dir=data_root,
    label_csv=label_path,
    is_train=True
)
# Pinned host memory only helps host-to-GPU copies, so request it only with CUDA.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=use_cuda)
print(f"Number of training batches: {len(train_loader)}")
model = MotorNet3D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Keep the scaler on the same device type as autocast; loss scaling is only
# enabled for CUDA (when disabled, scale/step/update pass straight through).
scaler = GradScaler(device.type, enabled=use_cuda)

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader):
        x = batch['volume'].to(device)        # [B, 1, D, H, W]
        gt_coords = batch['gt_coords']        # [B, N, 3]; moved to device inside motor_loss
        has_motor = batch['has_motor'].to(device)

        optimizer.zero_grad()
        with autocast(device.type):
            pred_coords, pred_conf = model(x)
            loss = motor_loss(pred_coords, pred_conf, gt_coords, has_motor)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()

    # max(..., 1) keeps the report from dividing by zero on an empty loader.
    print(f"Epoch {epoch+1}: Loss = {total_loss / max(len(train_loader), 1):.4f}")
    if use_cuda:
        torch.cuda.empty_cache()
torch.save(model.state_dict(), "motor_net.pth")

Number of training batches: 645


100%|██████████| 645/645 [2:06:53<00:00, 11.80s/it]


Epoch 1: Loss = 15998.0991


100%|██████████| 645/645 [1:56:50<00:00, 10.87s/it]


Epoch 2: Loss = 8971.5662


100%|██████████| 645/645 [1:51:47<00:00, 10.40s/it]

Epoch 3: Loss = 8869.7666





In [6]:
# Locations of the test tomograms and of the label CSV handed to the dataset.
test_root_dir = '/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/test'
label_csv = '/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/train_labels.csv'

# is_train=False makes the dataset keep only its hold-out tomogram ids;
# tomo_filter=None applies no further restriction.
test_dataset = MotorVolumeDataset(test_root_dir, label_csv, is_train=False, tomo_filter=None)

# One volume per batch (volumes are large), in a fixed order.
loader_options = dict(batch_size=1, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, **loader_options)

In [7]:
# Minimum predicted motor probability required to report a location.
CONF_THRESHOLD = 0.5

submission_rows = []

model.eval()
with torch.no_grad():
    for batch in tqdm(test_loader):
        volume = batch['volume'].to(device)           # [1, 1, D, H, W]
        tomo_ids = batch['tomo_id']                   # list of length 1

        pred_coord, pred_conf_logit = model(volume)   # [1, 3], [1, 1]

        # The confidence head outputs a raw logit (it is trained with
        # binary_cross_entropy_with_logits), so squash it to a probability
        # before comparing against the threshold.
        motor_prob = torch.sigmoid(pred_conf_logit.float()).item()

        if motor_prob < CONF_THRESHOLD:
            # Low confidence: report "no motor".
            original_pred_coord = [-1, -1, -1]
        else:
            # Map the prediction from the resized volume back to the original grid.
            # The default collate turns each inverse-scale float into a 1-element tensor.
            inverse_scale = np.array([s.item() for s in batch['inverse_scale']])
            resized_coord = pred_coord[0].float().cpu().numpy()
            original_pred_coord = (resized_coord * inverse_scale).tolist()

        submission_rows.append({
            'tomo_id': tomo_ids[0],
            'Motor axis 0': original_pred_coord[0],
            'Motor axis 1': original_pred_coord[1],
            'Motor axis 2': original_pred_coord[2],
        })

submission_df = pd.DataFrame(submission_rows)
submission_df.to_csv('submission.csv', index=False)

100%|██████████| 3/3 [00:41<00:00, 13.98s/it]


In [8]:
# Read back the file written by the inference cell, using the same relative
# path it was saved to, so the check does not depend on the working directory
# being /kaggle/working.
submission = pd.read_csv('submission.csv')

In [9]:
# Show the reloaded submission table as this cell's output.
submission

Unnamed: 0,tomo_id,Motor axis 0,Motor axis 1,Motor axis 2
0,tomo_003acc,-1,-1,-1
1,tomo_00e047,-1,-1,-1
2,tomo_01a877,-1,-1,-1
