In [None]:
import torchvision.models as models
import torch
import torch.nn as nn
from collections import OrderedDict

# Size of the CT volume's slice axis (the scan loaded later has shape
# (512, 512, 216)). Not used as the inflated kernel depth: each conv's depth
# kernel comes from its own kernel_size.
depth = 216

def inflate_conv(conv2d):
    """Inflate a 2D convolution into a 3D convolution (I3D-style bootstrapping).

    The 2D kernel of shape (out, in/groups, kh, kw) is repeated kh times along a
    new depth axis and divided by kh. The depth kernel size therefore equals
    ``kernel_size[0]``, consistent with the declared ``kernel_size``. An input
    that is constant along depth reproduces the 2D response in slices away from
    the depth borders.

    Args:
        conv2d: ``nn.Conv2d`` to inflate. It is only read, never modified.

    Returns:
        ``nn.Conv3d`` with kernel (kh, kh, kw), stride (sh, sh, sw),
        padding (ph, ph, pw) (string padding is passed through unchanged),
        dilation (dh, dh, dw), and the same groups, bias presence,
        padding_mode, device and dtype as ``conv2d``.
    """
    kernel_t = conv2d.kernel_size[0]

    def lift(pair):
        # Reuse the height entry for the new depth dimension.
        return (pair[0], pair[0], pair[1])

    # String padding ('same' / 'valid') already applies to every spatial dim.
    padding = conv2d.padding if isinstance(conv2d.padding, str) else lift(conv2d.padding)

    conv3d = nn.Conv3d(
        in_channels=conv2d.in_channels,
        out_channels=conv2d.out_channels,
        kernel_size=lift(conv2d.kernel_size),
        stride=lift(conv2d.stride),
        padding=padding,
        dilation=lift(conv2d.dilation),
        groups=conv2d.groups,
        bias=conv2d.bias is not None,
        padding_mode=conv2d.padding_mode,
    ).to(device=conv2d.weight.device, dtype=conv2d.weight.dtype)

    with torch.no_grad():
        # Divide by the repeat count so the depth-summed kernel equals the 2D
        # kernel. copy_ (unlike assigning .data) fails loudly on a shape
        # mismatch and does not alias conv2d's storage.
        conv3d.weight.copy_(conv2d.weight.unsqueeze(2).repeat(1, 1, kernel_t, 1, 1) / kernel_t)
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)
    return conv3d

# Pretrained 2D DenseNet-121 whose weights are inflated into 3D further down.
# The 3D model is derived from this one by inflate_densenet, so no separate
# (never-used) densenet121 instance is constructed here.
model_2d = models.densenet121(pretrained=True)

def _to_3d(value):
    """Lift a 2D size argument to 3D; ints stay ints (cubic), pairs reuse the height entry."""
    if isinstance(value, int):
        return value
    return (value[0], value[0], value[1])


def _inflate_batchnorm(bn2d):
    """Return a BatchNorm3d carrying bn2d's affine parameters and running statistics."""
    bn3d = nn.BatchNorm3d(
        bn2d.num_features,
        eps=bn2d.eps,
        momentum=bn2d.momentum,
        affine=bn2d.affine,
        track_running_stats=bn2d.track_running_stats,
    )
    # Parameter/buffer names and shapes are identical between the 2D and 3D variants.
    bn3d.load_state_dict(bn2d.state_dict())
    return bn3d


def _inflate_pool(pool2d):
    """Return the MaxPool3d / AvgPool3d equivalent of a MaxPool2d / AvgPool2d."""
    if isinstance(pool2d, nn.MaxPool2d):
        return nn.MaxPool3d(
            kernel_size=_to_3d(pool2d.kernel_size),
            stride=_to_3d(pool2d.stride),
            padding=_to_3d(pool2d.padding),
            dilation=_to_3d(pool2d.dilation),
            return_indices=pool2d.return_indices,
            ceil_mode=pool2d.ceil_mode,
        )
    return nn.AvgPool3d(
        kernel_size=_to_3d(pool2d.kernel_size),
        stride=_to_3d(pool2d.stride),
        padding=_to_3d(pool2d.padding),
        ceil_mode=pool2d.ceil_mode,
        count_include_pad=pool2d.count_include_pad,
        divisor_override=pool2d.divisor_override,
    )


def _inflate_children(parent):
    """Recursively replace every 2D conv/norm/pool child of ``parent`` with its 3D counterpart, in place."""
    # Snapshot the children so replacing entries does not disturb the iteration.
    for child_name, child in list(parent.named_children()):
        if isinstance(child, nn.Conv2d):
            replacement = inflate_conv(child)
        elif isinstance(child, nn.BatchNorm2d):
            replacement = _inflate_batchnorm(child)
        elif isinstance(child, (nn.MaxPool2d, nn.AvgPool2d)):
            replacement = _inflate_pool(child)
        else:
            _inflate_children(child)
            continue
        # Keep train/eval mode so BatchNorm behaves as it did in the 2D model.
        replacement.train(child.training)
        setattr(parent, child_name, replacement)


def inflate_densenet(model_2d):
    """Return a copy of ``model_2d`` with every 2D layer inflated to 3D.

    ``model_2d`` is deep-copied and left untouched. In the copy:
    - every ``nn.Conv2d`` is replaced by ``inflate_conv(...)``;
    - every ``nn.BatchNorm2d`` becomes an ``nn.BatchNorm3d`` with the same
      affine parameters and running statistics;
    - every ``nn.MaxPool2d`` / ``nn.AvgPool2d`` becomes the 3D pool with the
      same settings.
    All other modules (e.g. ReLU, the Linear classifier) are kept as-is.

    The original implementation instead loaded 5D conv weights into a freshly
    built *2D* densenet121, which fails with a size mismatch under strict=True.

    NOTE(review): torchvision's DenseNet.forward applies 2D global average
    pooling after ``features``. Presumably only ``model_3d.features`` accepts
    (N, C, D, H, W) input, so confirm before calling the full model.
    """
    import copy

    model_3d = copy.deepcopy(model_2d)
    _inflate_children(model_3d)
    # Newly built BatchNorm3d/pool layers start on the CPU; follow the source model's device.
    first_param = next(model_2d.parameters(), None)
    if first_param is not None:
        model_3d.to(first_param.device)
    return model_3d

model_3d = inflate_densenet(model_2d)
print(model_3d)

# Smoke-test the inflated stem on a CT-sized single-channel volume.
# conv0 was inflated from an RGB stem and expects 3 input channels, so the
# grayscale channel is broadcast. expand() is a view, so no extra memory is
# allocated for the copies.
x = torch.randn(1, 1, depth, 512, 512).expand(
    -1, model_3d.features.conv0.in_channels, -1, -1, -1
)
with torch.no_grad():
    output = model_3d.features.conv0(x)
print("Output shape from first 3D conv layer:", output.shape)




: 

In [2]:
# Reload the pretrained 2D DenseNet-121 and display it; the repr below is the
# 2D architecture that gets inflated to 3D.
from torchvision import models

model_2d = models.densenet121(pretrained=True)
model_2d



DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [5]:
import nibabel as nib
def load_ct_scan(file_path):
    """Read a NIfTI image with nibabel.

    Args:
        file_path: path to the image file (e.g. a ``.nii.gz``).

    Returns:
        Tuple ``(voxel_array, affine)``: the array from ``get_fdata()`` and the
        image's affine matrix.

    NOTE(review): nibabel's ``get_fdata`` returns float64 by default, so a
    512x512x216 volume is roughly 450 MB. Consider a smaller dtype if memory
    matters.
    """
    image = nib.load(file_path)
    return image.get_fdata(), image.affine


# Example knee CT volume, resolved relative to the notebook's working directory.
# NOTE(review): the slice axis is last here ((512, 512, 216) below), whereas
# the 3D model above is fed (N, C, D, H, W). Transpose before inference.
img_path = '3702_left_knee.nii.gz'
ct_data, _ = load_ct_scan(file_path=img_path)
ct_data.shape

(512, 512, 216)