In [1]:
import h5py
import os
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm

os.chdir('/mnt/sda3/yigedabuliu/lkq/brats-unet/')


In [2]:
import os
import argparse

from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from tqdm import tqdm
from BraTS import *


In [46]:
import torch.nn as nn
import torch.nn.functional as F


class InConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(InConv, self).__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool3d(2, 2),
            DoubleConv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class OutConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(OutConv, self).__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, 1)
        # self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv(x)
        # x = self.sigmoid(x)
        return x


class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class Up(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose3d(in_ch, in_ch, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_ch + skip_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # 将x1以插值的方式变到与x2相同尺寸
        if x1.shape != x2.shape:
            target_size = x2.shape[-3:]
            x1 = F.interpolate(x1, size=target_size, mode='trilinear', align_corners=False)

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(UNet, self).__init__()
        features = [32, 64, 128, 256]

        self.inc = InConv(in_channels, features[0])
        self.down1 = Down(features[0], features[1])
        self.down2 = Down(features[1], features[2])
        self.down3 = Down(features[2], features[3])
        self.down4 = Down(features[3], features[3])

        self.up1 = Up(features[3], features[3], features[2])
        self.up2 = Up(features[2], features[2], features[1])
        self.up3 = Up(features[1], features[1], features[0])
        self.up4 = Up(features[0], features[0], features[0])
        self.outc = OutConv(features[0], num_classes)

    def forward(self, x):
        x1 = self.inc(x)
        print(f'after inc shape:{x1.shape}')
        x2 = self.down1(x1)
        print(f'after down1 shape:{x2.shape}')
        x3 = self.down2(x2)
        print(f'after down2 shape:{x3.shape}')
        x4 = self.down3(x3)
        print(f'after down3 shape:{x4.shape}')
        x5 = self.down4(x4)
        print(f'after down4 shape:{x5.shape}')

        x = self.up1(x5, x4)
        print(f'after up1 shape:{x.shape}')

        x = self.up2(x, x3)
        print(f'after up2 shape:{x.shape}')

        x = self.up3(x, x2)
        print(f'after up3 shape:{x.shape}')

        x = self.up4(x, x1)
        print(f'after up4 shape:{x.shape}')

        x = self.outc(x)
        print(f'final shape:{x.shape}')
        return x



In [47]:
model = UNet(in_channels=1, num_classes=5)  # 4个类别包含了背景

In [50]:
from CHAOS import CHAOS_h5_Dataset

data_root = '/mnt/sda3/yigedabuliu/lkq/data/MR/CHAOS/h5_datasets/'
tr_dataset = CHAOS_h5_Dataset(data_root, data_aug=True, type='train', modal='MR_T1DUAL_InPhase')

train_loader = DataLoader(dataset=tr_dataset, batch_size=2, num_workers=12,  # num_worker=4
                          shuffle=True, pin_memory=True)

In [51]:
l = len(tr_dataset)
for i in range(l):
    img, label = tr_dataset[i]
    out = model(img.unsqueeze(dim=0))
    print(f'out shape:{out.shape}')

after inc shape:torch.Size([1, 32, 256, 256, 50])
after down1 shape:torch.Size([1, 64, 128, 128, 25])
after down2 shape:torch.Size([1, 128, 64, 64, 12])
after down3 shape:torch.Size([1, 256, 32, 32, 6])
after down4 shape:torch.Size([1, 256, 16, 16, 3])
x1 and x2 shape:(torch.Size([1, 256, 32, 32, 6]), torch.Size([1, 256, 32, 32, 6]))
after up1 shape:torch.Size([1, 128, 32, 32, 6])
x1 and x2 shape:(torch.Size([1, 128, 64, 64, 12]), torch.Size([1, 128, 64, 64, 12]))
after up2 shape:torch.Size([1, 64, 64, 64, 12])
x1 and x2 shape:(torch.Size([1, 64, 128, 128, 24]), torch.Size([1, 64, 128, 128, 25]))
after up3 shape:torch.Size([1, 32, 128, 128, 25])
x1 and x2 shape:(torch.Size([1, 32, 256, 256, 50]), torch.Size([1, 32, 256, 256, 50]))
after up4 shape:torch.Size([1, 32, 256, 256, 50])
final shape:torch.Size([1, 5, 256, 256, 50])
out shape:torch.Size([1, 5, 256, 256, 50])


KeyboardInterrupt: 

In [11]:
l

20

In [43]:
x1 = torch.empty(1, 1, 256, 256, 30)
x2 = torch.empty(1, 1, 256, 256, 31)

if x1.shape != x2.shape:
    x1 = F.interpolate(x1, size=(256, 256, 31), mode='trilinear', align_corners=False)
    print(f'x1 shape:{x1.shape}')

x1 shape:torch.Size([1, 1, 256, 256, 31])


In [45]:
x1.shape[-3:]

torch.Size([256, 256, 31])

In [41]:
import torch

# 假设你有一个张量 tensor，形状为（B，C，H，W，D）
tensor = torch.randn(1, 2, 256, 256, 30)

# 定义目标的维度 D1
D1 = 31

# 使用 interpolate 函数将张量插值到目标形状（B，C，H，W，D1）
resized_tensor = torch.nn.functional.interpolate(tensor, size=(256, 256, D1), mode='trilinear', align_corners=False)

# 打印调整后的张量形状
print(resized_tensor.shape)


torch.Size([1, 2, 256, 256, 31])
