# 1.1 Generate validation and test set

In [None]:
import os
import shutil
import natsort
import random

# Random sampling, generate validation and test sets.
def random_sample(num_class, val_rate, test_rate, num_types=10):
    """Move a random validation and test subset out of the training set.

    The files in ``./data/train/subtomogram_mrc`` are naturally sorted and
    cut into consecutive slices of ``num_class`` files, one slice per class.
    From every slice, ``int(num_class * val_rate)`` files are moved to the
    validation folders and ``int(num_class * test_rate)`` of the remaining
    files are moved to the test folders.  The JSON label that belongs to a
    subtomogram is moved together with it.

    Args:
        num_class: Number of files that make up one class slice.
        val_rate: Fraction of ``num_class`` moved to the validation set.
        test_rate: Fraction of ``num_class`` moved to the test set.
        num_types: Number of class slices to process (defaults to 10, the
            value that used to be hard-coded).

    Note:
        Files are *moved*, not copied, so the slicing only lines up with the
        classes on the first run over an untouched training folder.
    """
    train_dir_sub = './data/train/subtomogram_mrc'
    val_dir_sub = './data/val/subtomogram_mrc'
    test_dir_sub = './data/test/subtomogram_mrc'
    train_dir_json = './data/train/json_label'
    val_dir_json = './data/val/json_label'
    test_dir_json = './data/test/json_label'

    # exist_ok=True already covers the "directory is there" case.
    for directory in (val_dir_sub, val_dir_json, test_dir_sub, test_dir_json):
        os.makedirs(directory, exist_ok=True)

    def _move_pairs(file_names, dst_dir_sub, dst_dir_json):
        """Move each subtomogram and its JSON label out of the train folders."""
        for file_name in file_names:
            # The label file name is derived from the subtomogram file name.
            json_name = file_name.replace('tomotarget', 'target').replace('mrc', 'json')
            shutil.move(os.path.join(train_dir_sub, file_name),
                        os.path.join(dst_dir_sub, file_name))
            shutil.move(os.path.join(train_dir_json, json_name),
                        os.path.join(dst_dir_json, json_name))

    sorted_files = natsort.natsorted(os.listdir(train_dir_sub))
    num_val = int(num_class * val_rate)
    num_test = int(num_class * test_rate)

    for i in range(num_types):
        one_class_files = sorted_files[i * num_class:(i + 1) * num_class]
        val_files = random.sample(one_class_files, num_val)
        # Set lookup keeps the exclusion linear; the list keeps sorted order.
        val_lookup = set(val_files)
        one_class_files_rest = [x for x in one_class_files if x not in val_lookup]
        test_files = random.sample(one_class_files_rest, num_test)

        _move_pairs(val_files, val_dir_sub, val_dir_json)
        _move_pairs(test_files, test_dir_sub, test_dir_json)

        print(f'class {i} is done')

num_class = 500  # number of files per class slice (passed to random_sample)
val_rate = 0.2 # fraction of each class slice moved to the validation set
test_rate = 0.1 # fraction of each class slice moved to the test set
# Moves files on disk (shutil.move), so the split is not repeatable by re-running this cell.
random_sample(num_class, val_rate, test_rate)

# 1.2 Data Transforms

In [1]:
# https://github.com/Shadowalker1995/MOCO-Subtomograms/blob/master/CustomTransforms.py
import numpy as np
import torch
import random


class ToTensor:
    """Callable transform that turns an array-like volume into a ``torch.Tensor``."""

    def __call__(self, image):
        # Copy first so the tensor is built from a fresh array, not from the
        # caller's buffer.
        return torch.Tensor(image.copy())


class Normalize3D:
    """Channel-wise normalization transform for a voxel tensor.

    For ``n`` channels with ``mean = (mean[1], ..., mean[n])`` and
    ``std = (std[1], ..., std[n])`` every channel is mapped to
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``.

    .. note::
        With the default ``inplace=False`` the input tensor is left untouched.

    Args:
        mean (sequence): Per-channel means.
        std (sequence): Per-channel standard deviations.
        inplace (bool, optional): Normalize the given tensor in place.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std, self.inplace = mean, std, inplace

    def __call__(self, tensor):
        """Normalize a voxel tensor of size (C, D, H, W).

        Args:
            tensor (Tensor): Voxel tensor to normalize.

        Returns:
            Tensor: The normalized voxel tensor.
        """
        return normalize3D(tensor, mean=self.mean, std=self.std, inplace=self.inplace)

    def __repr__(self):
        return f'{type(self).__name__}(mean={self.mean}, std={self.std})'


def normalize3D(tensor, mean, std, inplace=False):
    """Subtract a per-channel mean and divide by a per-channel std.

    .. note::
        By default this works on a clone, i.e. it does not mutate the input tensor.

    Args:
        tensor (Tensor): Voxel tensor of size (C, D, H, W).
        mean (sequence): Per-channel means.
        std (sequence): Per-channel standard deviations.
        inplace (bool, optional): Modify ``tensor`` itself instead of a clone.

    Returns:
        Tensor: The normalized voxel tensor.

    Raises:
        TypeError: If ``tensor`` is not a torch tensor.
        ValueError: If ``tensor`` is not 4-dimensional, or any std is zero
            after conversion to the tensor's dtype.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError('tensor should be a torch tensor. Got {}.'.format(type(tensor)))

    if tensor.dim() != 4:
        raise ValueError('Expected tensor to be a tensor voxel of size (C, D, H, W). Got tensor.size() = '
                         '{}.'.format(tensor.size()))

    result = tensor if inplace else tensor.clone()

    # Bring the statistics to the tensor's dtype and device before using them.
    dtype = result.dtype
    mean_t = torch.as_tensor(mean, dtype=dtype, device=result.device)
    std_t = torch.as_tensor(std, dtype=dtype, device=result.device)
    if torch.any(std_t == 0):
        raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))

    # One value per channel -> shape (C, 1, 1, 1) so it broadcasts over D, H, W.
    if mean_t.ndim == 1:
        mean_t = mean_t.view(-1, 1, 1, 1)
    if std_t.ndim == 1:
        std_t = std_t.view(-1, 1, 1, 1)

    result.sub_(mean_t)
    result.div_(std_t)
    return result

# 1.3 Data load

In [2]:
# Class labels of the 10-class dataset. mapping_types() enumerates this list,
# so the position of a label is the integer target used for training.
# The entries look like PDB identifiers — presumably the macromolecule IDs; confirm.
# For a custom dataset, replace the labels with your own.
class_10 = [
    "1bxn",
    "1f1b",
    "1yg6",
    "2byu",
    "2h12",
    "2ldb",
    "3gl1",
    "3hhb",
    "4d4r",
    "6t3e",
]

In [None]:
# https://github.com/Shadowalker1995/MOCO-Subtomograms/blob/master/Custom_CryoET_DataLoader.py
import os
import torch
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import natsort
import mrcfile
import json
import random


def mapping_types(num_classes):
    """Return the ``label -> integer target`` mapping for a dataset size.

    Args:
        num_classes: Number of classes of the dataset. Only ``10`` (the
            ``class_10`` label list) is registered.

    Returns:
        dict: Maps every class label to its index in the label list.

    Raises:
        ValueError: If no label list is registered for ``num_classes``.
    """
    # To support another dataset, add a branch that selects its label list,
    # e.g. ``elif num_classes == <n>: labels = <label list>``.
    if num_classes != 10:
        raise ValueError(
            'No class labels are defined for num_classes={}; only 10 is supported.'.format(num_classes))
    labels = class_10

    label_to_target = {label: idx for idx, label in enumerate(labels)}
    return label_to_target

# Dataset
class CryoETDatasetLoader(Dataset):
    """Dataset of subtomogram volumes (MRC files) with JSON class labels.

    The files of ``root_dir`` and ``json_dir`` are naturally sorted and paired
    by position, so both folders must contain the same number of files in a
    matching order.

    Args:
        root_dir: Folder containing the subtomogram ``.mrc`` files.
        json_dir: Folder containing the JSON label files.
        transform: Optional callable applied to the ``(1, 32, 32, 32)``
            float32 array of every sample.
        num_classes: Number of classes, forwarded to ``mapping_types``
            (defaults to 10, the value that used to be hard-coded).

    Raises:
        ValueError: If the two folders hold a different number of files.
    """

    def __init__(self, root_dir, json_dir, transform=None, num_classes=10):
        self.root_dir = root_dir
        self.json_dir = json_dir
        self.transform = transform
        self.total_imgs = natsort.natsorted(os.listdir(root_dir))
        self.total_jsons = natsort.natsorted(os.listdir(json_dir))
        print(f'{len(self.total_imgs)}, vs {len(self.total_jsons)}')
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(self.total_imgs) != len(self.total_jsons):
            raise ValueError(
                'Number of subtomograms ({}) in {} does not match number of labels ({}) in {}.'.format(
                    len(self.total_imgs), root_dir, len(self.total_jsons), json_dir))
        self.label_to_target = mapping_types(num_classes)

    def __len__(self):
        return len(self.total_jsons)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(volume, target)``.

        Raises:
            ValueError: If the MRC file yields no data or its data cannot be
                arranged as a 32 x 32 x 32 volume.
        """
        path_img = os.path.join(self.root_dir, self.total_imgs[idx])
        path_json = os.path.join(self.json_dir, self.total_jsons[idx])

        # Read-only access is enough; the file is never written.
        with mrcfile.open(path_img, mode='r', permissive=True) as mrc:
            mrc_img = mrc.data
            if mrc_img is None:
                raise ValueError('No data could be read from {}'.format(path_img))
            try:
                # astype() copies, so the array stays valid after the file closes.
                mrc_img = mrc_img.astype(np.float32).transpose((2, 1, 0)).reshape((1, 32, 32, 32))
            except ValueError as err:
                # Fail with the offending file instead of returning a sample
                # of the wrong shape.
                raise ValueError(
                    'Unexpected subtomogram shape {} in {}; expected 32 x 32 x 32 voxels.'.format(
                        mrc_img.shape, path_img)) from err

        with open(path_json) as f:
            mrc_json = json.load(f)

        # The class label is stored under the "name" key of the JSON file.
        target = self.label_to_target[mrc_json['name']]

        if self.transform is not None:
            transformed_mrc_img = self.transform(mrc_img)
        else:
            transformed_mrc_img = mrc_img

        return transformed_mrc_img, target

# Data paths: one folder of subtomograms and one folder of JSON labels per split.
traindir = os.path.join('data', 'train/subtomogram_mrc')
traindir_json = os.path.join('data', 'train/json_label')
valdir = os.path.join('data', 'val/subtomogram_mrc')
valdir_json = os.path.join('data', 'val/json_label')
testdir = os.path.join('data', 'test/subtomogram_mrc')
testdir_json = os.path.join('data', 'test/json_label')
batch_size = 32  # also embedded in the checkpoint file names by the training cell
workers = 4  # number of DataLoader worker processes

# Single-channel normalization statistics for the val and train pipelines.
# NOTE(review): presumably precomputed from the respective splits — confirm.
stage_normalize_val = Normalize3D(mean=[0.04725085], std=[13.48426468])
stage_normalize_train = Normalize3D(mean=[0.05964008], std=[13.57436941])

# Transform pipelines: convert to tensor, then normalize (no random augmentation is applied).
augmentation_train = [
    ToTensor(),
    stage_normalize_train,
]

augmentation_val = [
    ToTensor(),
    stage_normalize_val,
]

# Training data: shuffled, and the last incomplete batch is dropped.
train_dataset = CryoETDatasetLoader(
    root_dir=traindir, json_dir=traindir_json,
    transform=transforms.Compose(augmentation_train))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=workers, pin_memory=True, drop_last=True)

# Validation data: fixed order, all samples kept.
val_dataset = CryoETDatasetLoader(
    root_dir=valdir, json_dir=valdir_json,
    transform=transforms.Compose(augmentation_val))

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False,
    num_workers=workers, pin_memory=True)

# Test data: reuses the validation transform pipeline (including its normalization statistics).
test_dataset = CryoETDatasetLoader(
    root_dir=testdir, json_dir=testdir_json,
    transform=transforms.Compose(augmentation_val))

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False,
    num_workers=workers, pin_memory=True)

# 2.1 Model Definition

In [6]:
# https://github.com/Shadowalker1995/MOCO-Subtomograms/tree/master/Encoder3D
# RB3D
import torch
import torch.nn as nn


class Bottleneck(nn.Module):
    """Two-path 3D block that maps 32 channels to 32 channels.

    The right path is a single 3x3x3 convolution (32 -> 16).  The left path
    is a 1x1x1 -> 3x3x3 -> 1x1x1 stack (32 -> 16 -> 32 -> 16).  Both 16-channel
    results are concatenated along the channel axis and passed through ReLU.
    Spatial size is unchanged.
    """

    def __init__(self):
        super().__init__()
        spatial = dict(kernel_size=3, stride=1, padding=1)
        pointwise = dict(kernel_size=1, stride=1, padding=0)
        # Layers are created in the order conv1..conv4.
        self.conv1 = nn.Conv3d(32, 16, **spatial)
        self.conv2 = nn.Conv3d(32, 16, **pointwise)
        self.conv3 = nn.Conv3d(16, 32, **spatial)
        self.conv4 = nn.Conv3d(32, 16, **pointwise)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Right path: 32 -> 16 channels.
        right = self.relu(self.conv1(x))
        # Left path: 32 -> 16 -> 32 -> 16 channels; no ReLU after the last conv.
        left = self.relu(self.conv2(x))
        left = self.relu(self.conv3(left))
        left = self.conv4(left)
        # 16 + 16 -> 32 channels, left path first.
        merged = torch.cat((left, right), dim=1)
        return self.relu(merged)


class RB3D(nn.Module):
    """3D CNN: one convolution, max-pooling, four ``Bottleneck`` blocks, pooling.

    Args:
        num_classes: Output size of the final linear layer.
        keepfc: If true, a ``Linear(256, num_classes)`` head is created and
            applied; otherwise ``forward`` returns the 256 pooled features.
    """

    def __init__(self, num_classes=10, keepfc=True):
        super().__init__()
        # Dimensions of the 3D image: channels, depth, height, width.
        self.C, self.D, self.H, self.W = 1, 32, 32, 32
        self.keepfc = keepfc

        self.conv1 = nn.Conv3d(self.C, 32, kernel_size=3, stride=1, padding=1)
        self.maxpool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
        self.bottleneck_layers = nn.Sequential(*[Bottleneck() for _ in range(4)])
        self.avgpool = nn.AdaptiveAvgPool3d((2, 2, 2))
        if keepfc:
            # 32 channels x 2 x 2 x 2 pooled voxels = 256 features.
            self.fc = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()

        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                # No batch-norm layer is created in this model, so this
                # branch does not match any module here.
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)

    def forward(self, x):
        # C -> 32 channels, spatial size unchanged.
        features = self.relu(self.conv1(x))
        # Halve every spatial dimension.
        features = self.maxpool(features)
        # Four bottleneck blocks keep 32 channels and the spatial size.
        features = self.bottleneck_layers(features)
        # Pool to 32 x 2 x 2 x 2 and flatten to 256 features per sample.
        features = torch.flatten(self.avgpool(features), 1)
        if not self.keepfc:
            return features
        # 256 -> num_classes
        return self.fc(features)

In [7]:
# DSRF3D_v2
class DSRF3D_v2(nn.Module):
    """Six-layer 3D CNN (32-32-64-64-128-128 channels) with global pooling.

    Args:
        num_classes: Output size of the final linear layer.
        keepfc: If true, a ``Linear(128, num_classes)`` head is created and
            applied; otherwise ``forward`` returns the 128 pooled features.
    """

    def __init__(self, num_classes=10, keepfc=True):
        super(DSRF3D_v2, self).__init__()
        in_channels = 1  # single-channel input volume
        self.keepfc = keepfc

        self.conv1 = nn.Conv3d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv3d(32, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv3d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv3d(128, 128, kernel_size=3, stride=1, padding=1)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

        if keepfc:
            self.fc = nn.Linear(128, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
                # No batch-norm layer exists in this model today; the 3D
                # variant is included so one added later is initialized too.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)

    def forward(self, x):
        # Resolve the functional ops through ``nn`` so this class does not
        # depend on an ``F`` alias that is only imported in a later cell.
        relu = nn.functional.relu
        max_pool3d = nn.functional.max_pool3d

        # 1 x 32 x 32 x 32 -> 32 x 32 x 32 x 32
        x = relu(self.conv1(x))
        # 32 x 32 x 32 x 32 -> 32 x 32 x 32 x 32
        x = relu(self.conv2(x))
        # 32 x 32 x 32 x 32 -> 32 x 16 x 16 x 16
        x = max_pool3d(x, kernel_size=2, stride=2, padding=0)
        # 32 x 16 x 16 x 16 -> 64 x 16 x 16 x 16
        x = relu(self.conv3(x))
        # 64 x 16 x 16 x 16 -> 64 x 16 x 16 x 16
        x = relu(self.conv4(x))
        # 64 x 16 x 16 x 16 -> 64 x 8 x 8 x 8
        x = max_pool3d(x, kernel_size=2, stride=2, padding=0)
        # 64 x 8 x 8 x 8 -> 128 x 8 x 8 x 8
        x = relu(self.conv5(x))
        # 128 x 8 x 8 x 8 -> 128 x 8 x 8 x 8
        x = relu(self.conv6(x))
        # 128 x 8 x 8 x 8 -> 128 x 1 x 1 x 1
        x = self.avgpool(x)
        # 128 x 1 x 1 x 1 -> 128
        x = torch.flatten(x, 1)
        if self.keepfc:
            # 128 -> num_classes
            x = self.fc(x)

        return x

In [8]:
# YOPO
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicConv3d(nn.Module):
    """3D convolution followed by ELU and then batch normalization.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        **kwargs: Forwarded unchanged to ``nn.Conv3d`` (kernel size, padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, **kwargs)
        self.elu = nn.ELU(alpha=1.0)
        # NOTE(review): in PyTorch ``momentum`` is the weight of the newest
        # batch statistic; 0.99 looks like a Keras-convention value (which
        # would correspond to 0.01 here) — confirm before changing.
        self.bn = nn.BatchNorm3d(out_channels, eps=0.001, momentum=0.99)

    def forward(self, x):
        # conv changes the channel count; ELU and batch norm keep the shape.
        activated = self.elu(self.conv(x))
        return self.bn(activated)


class YOPO(nn.Module):
    """Multi-branch 3D CNN whose classifier sees every layer's pooled output.

    Four stem branches (kernel sizes 3, 4, 5, 7) run on the input volume.
    Their final feature maps are concatenated (9 + 6 + 4 + 2 = 21 channels)
    and fed to four head branches (kernel sizes 3, 4, 5, 6).  The output of
    every ``BasicConv3d`` layer is globally average-pooled; all pooled vectors
    together form a 290-dimensional descriptor that goes through a densely
    connected stack of linear layers.

    Args:
        num_classes: Output size of the last linear layer.
        keepfc: Accepted for signature parity with the other models; it is
            not used by this class.
    """

    # (branch index, kernel size, channel widths from branch input to last layer)
    _STEM_BRANCHES = (
        (1, 3, (1, 4, 5, 6, 7, 8, 9)),
        (2, 4, (1, 3, 4, 5, 6)),
        (3, 5, (1, 2, 3, 4)),
        (4, 7, (1, 1, 2)),
    )
    _HEAD_BRANCHES = (
        (5, 3, (21, 10, 11, 12, 13, 14, 15, 16, 17, 18)),
        (6, 4, (21, 7, 8, 9, 10, 11, 12)),
        (7, 5, (21, 5, 6, 7, 8)),
        (8, 6, (21, 3, 4, 5)),
    )

    def __init__(self, num_classes=10, keepfc=True):
        super().__init__()

        self.avgpool = nn.AdaptiveAvgPool3d(1)

        # Register conv_<branch>_<layer> in branch order, then layer order;
        # the attribute names are the state_dict keys.
        for branch, kernel, widths in self._STEM_BRANCHES + self._HEAD_BRANCHES:
            for layer, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
                setattr(self, 'conv_{}_{}'.format(branch, layer),
                        BasicConv3d(c_in, c_out, kernel_size=kernel, padding='valid'))

        # Pooled descriptor sizes per branch:
        # 39 + 18 + 9 + 3 + 126 + 57 + 26 + 12 = 290
        self.bn1 = nn.BatchNorm1d(290, eps=0.001, momentum=0.99)
        self.linear1 = nn.Linear(290, 256)
        self.bn2 = nn.BatchNorm1d(256, eps=0.001, momentum=0.99)
        self.linear2 = nn.Linear(546, 128)
        self.bn3 = nn.BatchNorm1d(128, eps=0.001, momentum=0.99)
        self.linear3 = nn.Linear(674, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.BatchNorm3d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)

    def _run_branch(self, branch, depth, x, pooled):
        """Apply conv_<branch>_1..depth in sequence, appending each layer's pooled output."""
        for layer in range(1, depth + 1):
            x = getattr(self, 'conv_{}_{}'.format(branch, layer))(x)
            pooled.append(torch.flatten(self.avgpool(x), 1))
        return x

    def forward(self, x):
        # Pooled per-layer descriptors, collected in branch order, then layer order.
        pooled = []

        # Stage 1: every stem branch starts from the raw input.
        stem_maps = tuple(
            self._run_branch(branch, len(widths) - 1, x, pooled)
            for branch, _, widths in self._STEM_BRANCHES)
        # (9+6+4+2)=21 channels
        x = torch.cat(stem_maps, dim=1)

        # Stage 2: every head branch starts from the concatenated stem maps.
        for branch, _, widths in self._HEAD_BRANCHES:
            self._run_branch(branch, len(widths) - 1, x, pooled)

        # 290 pooled features per sample
        m = torch.cat(pooled, dim=1)

        fc1 = self.bn1(m)
        # 290 -> 256
        m = F.elu(self.linear1(fc1))
        # 290+256=546: each linear layer also sees all earlier normalized features
        fc2 = torch.cat((fc1, self.bn2(m)), dim=1)
        # 546 -> 128
        m = F.elu(self.linear2(fc2))
        # 546+128=674
        fc3 = torch.cat((fc2, self.bn3(m)), dim=1)
        # 674 -> num_classes
        return self.linear3(fc3)

# 2.2 Train and Val

In [11]:
# https://github.com/Shadowalker1995/MOCO-Subtomograms/blob/master/main_moco.py
import math
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms

# Train for one epoch
def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch and print running metrics.

    Reads the module-level ``gpu`` variable: when it is not ``None`` every
    batch is moved to that CUDA device.

    Args:
        train_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to optimize.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch index, used only in the progress prefix.
    """
    print_freq = 10  # print the meters every 10 batches

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':6.2f')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix=f"Epoch: [{epoch}]")

    model.train()

    tick = time.time()
    for step, (images, target) in enumerate(train_loader):
        # Time spent waiting for the batch.
        data_time.update(time.time() - tick)

        if gpu is not None:
            images = images.cuda(gpu, non_blocking=True)
            target = target.cuda(gpu, non_blocking=True)

        # Forward pass and loss.
        output = model(images)
        loss = criterion(output, target)

        # Top-1 / top-5 classification accuracy for this batch.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        n_samples = images.size(0)
        losses.update(loss.item(), n_samples)
        top1.update(acc1[0], n_samples)
        top5.update(acc5[0], n_samples)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Time for the whole iteration, including data loading.
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % print_freq == 0:
            progress.display(step)


def validate(val_loader, model, criterion, epoch=None):
    """Evaluate ``model`` on ``val_loader`` and return the average loss.

    Reads the module-level ``gpu`` variable: when it is not ``None`` every
    batch is moved to that CUDA device.

    Args:
        val_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate (switched to eval mode).
        criterion: Loss called as ``criterion(output, target)``.
        epoch: Optional epoch index; when given it is shown in the progress
            prefix.  Accepting it keeps ``validate(loader, model, criterion,
            epoch)`` calls working.

    Returns:
        The sample-weighted average loss over all batches.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':6.2f')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    prefix = "Val: " if epoch is None else "Val: [{}]".format(epoch)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix=prefix)

    # switch to evaluate mode
    model.eval()
    print_freq = 10
    end = time.time()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            if gpu is not None:
                images = images.cuda(gpu, non_blocking=True)
                target = target.cuda(gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % print_freq == 0:
                progress.display(i)
    return losses.avg


def save_checkpoint(state, is_best, filename='checkpoint/model_best.pth.tar'):
    """Save ``state`` to ``filename`` and optionally publish it as the best model.

    Args:
        state: Object to serialize with ``torch.save`` (typically a dict
            holding the model and optimizer state).
        is_best: If true, the saved file is also copied to
            ``checkpoint/model_best.pth.tar``.
        filename: Destination path of the checkpoint.
    """
    best_path = 'checkpoint/model_best.pth.tar'

    # torch.save does not create missing folders.
    os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
    torch.save(state, filename)

    # Copying a file onto itself raises shutil.SameFileError, which would
    # otherwise happen with the default ``filename``.
    if is_best and os.path.abspath(filename) != os.path.abspath(best_path):
        os.makedirs(os.path.dirname(best_path), exist_ok=True)
        shutil.copyfile(filename, best_path)


class AverageMeter:
    """Tracks the latest value and the weighted running average of a metric.

    Args:
        name: Label printed in front of the values.
        fmt: Format spec (including the leading ``:``) applied to the latest
            value and to the average when the meter is printed.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed for ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. '{name} {val:6.2f} ({avg:6.2f})' and fills it from the
        # instance attributes.
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))


class ProgressMeter:
    """Prints a ``[batch/total]`` counter followed by a list of meters.

    Args:
        num_batches: Total number of batches, shown after the slash.
        meters: Objects whose ``str()`` is printed on every ``display`` call.
        prefix: Text printed in front of the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print one tab-separated progress line for batch index ``batch``."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running index to the width of the total, e.g. '[{:3d}/100]'.
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[{}/{}]'.format(slot, slot.format(num_batches))


def adjust_learning_rate(lr,schedule,cos,optimizer, epoch,epochs):
    """Set the learning rate of every parameter group for ``epoch``.

    Args:
        lr: Base learning rate.
        schedule: Milestone epochs; each milestone that has been reached
            multiplies the rate by 0.1 (used when ``cos`` is false).
        cos: If true, use a half-cosine decay over ``epochs`` instead.
        optimizer: Object whose ``param_groups`` receive the new rate.
        epoch: Current epoch index.
        epochs: Total number of epochs (cosine schedule only).
    """
    if cos:
        # cosine lr schedule
        lr = lr * (0.5 * (1. + math.cos(math.pi * epoch / epochs)))
    else:
        # stepwise lr schedule
        for milestone in schedule:
            factor = 0.1 if epoch >= milestone else 1.
            lr = lr * factor
    for group in optimizer.param_groups:
        group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for every k in ``topk``.

    Args:
        output: Score tensor of shape (batch, classes).
        target: Tensor of ground-truth class indices, one per sample.
        topk: Values of k to evaluate.

    Returns:
        list: One single-element tensor per entry of ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the best `largest_k` classes per sample, transposed to (k, batch).
        _, top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)
        top_indices = top_indices.t()
        hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

        # A sample counts for k if the target is among its first k predictions.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]

In [None]:
# Define the model, you can also customize the model.
model_dictionary = {'RB3D': RB3D, 'DSRF3D_v2': DSRF3D_v2,
                    'YOPO': YOPO}
# Training parameters
arch = 'RB3D'
lr = 0.003
momentum = 0.9
weight_decay = 1e-4
epochs = 600
cos = False # cosine lr schedule
schedule = [300,400,500]  # stepwise lr schedule
gpu = 1 # index of the CUDA device to use; also read by train() and validate()

# Adjust num_classes based on the data.
model = model_dictionary[arch](num_classes=10)
model.cuda(gpu)
criterion = nn.CrossEntropyLoss().cuda(gpu)
optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

# torch.save does not create the target folder.
os.makedirs('checkpoint', exist_ok=True)

# Start from +inf so the first validation always produces a "best" checkpoint.
best_loss = float('inf')
for epoch in range(epochs):
    adjust_learning_rate(lr,schedule,cos,optimizer, epoch,epochs)

    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)

    # Validate and checkpoint every 20 epochs.
    if (epoch + 1) % 20 == 0:
        # validate() is called with its three positional parameters only.
        avg_loss = validate(val_loader, model, criterion)
        state = {
            'epoch': epoch + 1,
            'arch': arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        # save the best model (the file name is explicit, so is_best stays False)
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint(state, is_best=False,
                filename='checkpoint/arch-{}_bs{}_lr{}_best.pth.tar'.format(
                    arch, batch_size, lr))
        # save the last model
        save_checkpoint(state, is_best=False,
            filename='checkpoint/arch-{}_bs{}_lr{}_last.pth.tar'.format(
                arch, batch_size, lr))

# 2.3 Test

In [None]:
# Evaluate a saved checkpoint on the test set.
model_dictionary = {'RB3D': RB3D, 'DSRF3D_v2': DSRF3D_v2,
                    'YOPO': YOPO}
# test parameters
arch = 'RB3D'
gpu = 1 # index of the CUDA device to use; also read by validate()
# Must match the file name pattern written by the training cell (arch, batch size, lr).
checkpoint_path = 'checkpoint/arch-RB3D_bs32_lr0.003_best.pth.tar'

# Adjust num_classes based on the data.
model = model_dictionary[arch](num_classes=10)
# Load on the CPU first; the model is moved to the GPU afterwards.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
state_dict = checkpoint['state_dict']
# NOTE(review): strict=False ignores missing/unexpected keys, so a checkpoint from a
# different architecture would load without an error — confirm this is intended.
model.load_state_dict(state_dict, strict=False)
model.cuda(gpu)
criterion = nn.CrossEntropyLoss().cuda(gpu)
# validate() prints running metrics and returns the average test loss.
validate(test_loader,model,criterion)