In [1]:
import torch
from monai.networks.nets import UNet
from monai.inferers import sliding_window_inference

### Load test tomograms

In [2]:
import numpy as np
from tqdm import tqdm
import copick

copick_config_path = "configs/test_config.json" 
root = copick.from_file(copick_config_path)

In [3]:
test_dataset = []
for run in tqdm(root.runs):
    tomo = run.get_voxel_spacing(10).get_tomogram('denoised').numpy()
    test_dataset.append({'image': tomo})


100%|██████████| 3/3 [00:00<00:00,  5.02it/s]


In [4]:
len(test_dataset)

3

### Create dataloader for the test dataset

In [5]:
from monai.data import DataLoader, CacheDataset
from monai.transforms import (
    Compose, 
    NormalizeIntensityd,
    EnsureChannelFirstd, 
    Activationsd,
    AsDiscreted
)

# define pre transforms
pre_transforms = Compose(
    [
        EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        NormalizeIntensityd(keys=["image"]),
    ]
)


test_ds = CacheDataset(data=test_dataset, transform=pre_transforms)
test_loader = DataLoader(test_ds, batch_size=4, shuffle=False, num_workers=4, pin_memory=torch.cuda.is_available())

Loading dataset: 100%|██████████| 3/3 [00:00<00:00, 10.60it/s]


### Load model and weights

In [6]:
from models import CryoETUNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

model = CryoETUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=len(root.pickable_objects)+1,
    channels=(48, 64, 80, 80),
    strides=(2, 2, 1),
    num_res_units=1,
).to(device)

model.load_state_dict(torch.load("models/baseline/best_metric_model.pth", weights_only=True))
model.eval()

cuda


CryoETUNet(
  (model): Sequential(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
          (adn): ADN(
            (N): InstanceNorm3d(48, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
      )
      (residual): Conv3d(1, 48, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): ResidualUnit(
          (conv): Sequential(
            (unit0): Convolution(
              (conv): Conv3d(48, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
              (adn): ADN(
                (N): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
                (D): Dropout(p=0.0, inplace=False)
                (A): PReLU(

### Inference

In [7]:
def inference(model, input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(96, 96, 96),
            sw_batch_size=4,  # one window is proecessed at a time
            predictor=model,
            overlap=0.5,
        )

    with torch.cuda.amp.autocast():
        return _compute(input)

In [8]:
from monai.data import decollate_batch
from tqdm import tqdm

post_transforms = Compose([
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="pred", argmax=True)
])

predictions = []
with torch.no_grad():
    for data in tqdm(test_loader):
        tomogram = data['image'].to(device)  # only support batch=1 and channel first
        data["pred"] = inference(model, tomogram)
        data = [post_transforms(i) for i in decollate_batch(data)]
        for b in data:
            predictions.append(b['pred'].squeeze(0).numpy(force=True))

  with torch.cuda.amp.autocast():
100%|██████████| 1/1 [00:06<00:00,  6.64s/it]


In [9]:
print(np.unique(predictions[0]))

[0. 1. 3. 4. 5. 6.]


In [10]:
import copick_utils
from copick_utils.segmentation.picks_from_segmentation import picks_from_segmentation


particles = dict()
for po in root.config.pickable_objects:
    particles[po.name] = po.label

    
maxima_filter_size = 10
min_particle_size = 0
max_particle_size = 10
new_session_id = "1"
new_user_id = "paintedFromInferencePicks"
for prediction, run in tqdm(zip(predictions, root.runs)):
    for po in particles.keys():
        if po != "membrane":
            class_label = particles[po]
            picks_from_segmentation(prediction, class_label, maxima_filter_size, min_particle_size, max_particle_size, new_session_id, new_user_id, po, run, voxel_spacing=10)

  Expected `CopickPoint` but got `dict` with value `{'x': 4140.0, 'y': 290.0, 'z': 10.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 1200.0, 'y': 2.0, 'z': 125.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2400.0, 'y': 1710.0, 'z': 215.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2390.0, 'y': 1720.0, 'z': 220.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 6280.0, 'y': 2210.0, 'z': 460.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 6288.333333333334, 'y': 2210.0, 'z': 490.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5580.0, 'y': 1930.0, 'z': 860.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x'

Centroids for label 1 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 5880.0, 'y': 4340.0, 'z': 30.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5901.666666666666, 'y': 4340.0, 'z': 30.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5911.0, 'y': 4334.0, 'z': 41.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4160.0, 'y': 130.0, 'z': 50.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 1570.0, 'y': 880.0, 'z': 80.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 1580.0, 'y': 920.0, 'z': 90.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2068.0, 'y': 1880.0, 'z': 282.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 45

Centroids for label 3 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 2940.0, 'y': 470.0, 'z': 130.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 6120.0, 'y': 5112.0, 'z': 128.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2930.0, 'y': 460.0, 'z': 140.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 6118.0, 'y': 5100.0, 'z': 138.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 6142.5, 'y': 5087.5, 'z': 147.5}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5187.777777777778, ...'z': 178.88888888888889}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5200.0, 'y': 2590.0, 'z': 190.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value

Centroids for label 4 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 3497.5, 'y': 1887.5, 'z': 107.5}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 3510.0, 'y': 1890.0, 'z': 120.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 740.0, 'y': 1202.5, 'z': 150.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 990.0, 'y': 2090.0, 'z': 310.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4362.0, 'y': 6060.0, 'z': 348.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4350.0, 'y': 6080.0, 'z': 380.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4338.0, 'y': 6080.0, 'z': 392.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4282.5

Centroids for label 5 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 1532.0, 'y': 940.0, 'z': 108.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 1520.0, 'y': 940.0, 'z': 120.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 1532.0, 'y': 940.0, 'z': 132.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2860.0, 'y': 1000.0, 'z': 160.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2365.0, 'y': 2320.0, 'z': 380.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 6150.0, 'y': 5670.0, 'z': 400.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 6026.0, 'y': 4413.0, 'z': 489.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 6190.0,

Centroids for label 6 saved successfully.
No segmentation with label 9 found.


  Expected `CopickPoint` but got `dict` with value `{'x': 2960.0, 'y': 150.0, 'z': 890.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2950.0, 'y': 170.0, 'z': 890.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2948.75, 'y': 153.75, 'z': 902.5}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2648.0, 'y': 4098.0, 'z': 1170.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2660.0, 'y': 4110.0, 'z': 1170.0}` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


Centroids for label 1 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 1810.0, 'y': 4810.0, 'z': 570.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4640.0, 'y': 1600.0, 'z': 590.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4240.0, 'y': 5650.0, 'z': 820.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2590.0, 'y': 320.0, 'z': 840.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4255.0, 'y': 5660.0, 'z': 830.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4230.0, 'y': 5660.0, 'z': 850.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4171.111111111111, ... 'z': 1058.888888888889}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with valu

Centroids for label 3 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 5760.0, 'y': 6160.0, 'z': 500.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2758.0, 'y': 830.0, 'z': 612.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2900.0, 'y': 770.0, 'z': 750.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5090.0, 'y': 5770.0, 'z': 820.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5010.0, 'y': 1750.0, 'z': 840.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5022.5, 'y': 1762.5, 'z': 837.5}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5910.0, 'y': 3680.0, 'z': 850.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5920.0

Centroids for label 4 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 760.0, 'y': 5500.0, 'z': 370.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 770.0, 'y': 5510.0, 'z': 390.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5818.0, 'y': 6078.0, 'z': 400.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5818.0, 'y': 6048.0, 'z': 430.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 20.0, 'y': 4450.0, 'z': 440.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5829.0, 'y': 6049.0, 'z': 456.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 1320.0, 'y': 1430.0, 'z': 500.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 1308.0, 

Centroids for label 5 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 940.0, 'y': 370.0, 'z': 480.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 940.0, 'y': 382.0, 'z': 492.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 3680.0, 'y': 4330.0, 'z': 700.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2160.0, 'y': 6020.0, 'z': 890.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 1140.0, 'y': 3300.0, 'z': 910.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2380.0, 'y': 2390.0, 'z': 1080.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2380.0, 'y': 2402.0, 'z': 1092.0}` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(
2it [01:33, 46.

Centroids for label 6 saved successfully.
No segmentation with label 9 found.


  Expected `CopickPoint` but got `dict` with value `{'x': 5730.0, 'y': 3320.0, 'z': 270.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 170.0, 'y': 2420.0, 'z': 300.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 3370.0, 'y': 2610.0, 'z': 410.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 3970.0, 'y': 2580.0, 'z': 560.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 3958.0, 'y': 2592.0, 'z': 560.0}` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(


Centroids for label 1 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 5770.0, 'y': 3390.0, 'z': 90.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5780.0, 'y': 3400.0, 'z': 100.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4062.0, 'y': 2560.0, 'z': 388.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4050.0, 'y': 2560.0...'z': 398.33333333333337}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4050.0, 'y': 2560.0, 'z': 420.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 120.0, 'y': 3380.0, 'z': 510.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 2180.0, 'y': 4200.0, 'z': 510.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value

Centroids for label 3 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 4260.0, 'y': 520.0, 'z': 100.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4250.0, 'y': 540.0, 'z': 100.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5048.0, 'y': 710.0, 'z': 262.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 140.0, 'y': 5300.0, 'z': 480.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4770.0, 'y': 3590.0, 'z': 670.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 980.0, 'y': 5490.0, 'z': 730.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4264.0, 'y': 2899.0, 'z': 739.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4271.1111

Centroids for label 4 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 4260.0, 'y': 500.0, 'z': 30.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5330.0, 'y': 2010.0, 'z': 30.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5362.0, 'y': 2032.0, 'z': 40.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5352.0, 'y': 2020.0, 'z': 62.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 5340.0, 'y': 2010.0, 'z': 70.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4590.0, 'y': 2270.0, 'z': 190.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4012.0, 'y': 3490.0, 'z': 268.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 3930.0, 'y

Centroids for label 5 saved successfully.


  Expected `CopickPoint` but got `dict` with value `{'x': 2260.0, 'y': 2380.0, 'z': 260.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 1060.0, 'y': 1860.0, 'z': 360.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 3370.0, 'y': 2820.0, 'z': 450.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4860.0, 'y': 460.0, 'z': 630.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4850.0, 'y': 450.0, 'z': 640.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 4040.0, 'y': 6288.333333333334, 'z': 700.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{'x': 3650.0, 'y': 3710.0, 'z': 730.0}` - serialized value may not be as expected
  Expected `CopickPoint` but got `dict` with value `{

Centroids for label 6 saved successfully.
No segmentation with label 9 found.





In [11]:
particles

{'apo-ferritin': 1,
 'beta-galactosidase': 3,
 'ribosome': 4,
 'thyroglobulin': 5,
 'virus-like-particle': 6,
 'membrane': 8,
 'background': 9}

In [21]:
import pandas as pd

inference_points = []
idx = -1

for run in tqdm(root.runs):
    print("------------------------------------")
    print("Experiment name: ", run.name)
    for pname in particles:
        pick = run.get_picks(object_name=pname, user_id=new_user_id)
        if pick:
            for p in pick[0].points:
                idx += 1
                inference_points.append({
                    'id': idx,
                    'experiment': run.name,
                    'particle_type': pname,
                    'x': p['x'],
                    'y': p['y'],
                    'z': p['z']
                })


df = pd.DataFrame(inference_points)
df.to_csv("submission.csv", index=False)
    

100%|██████████| 3/3 [00:00<00:00, 5899.16it/s]

------------------------------------
Experiment name:  TS_5_4
------------------------------------
Experiment name:  TS_69_2
------------------------------------
Experiment name:  TS_6_4



