In [130]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class TopTPercentChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, percent_t=0.7, pool_types=['avg', 'max']):
        super(TopTPercentChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.percent_t = percent_t
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types
        
    def forward(self, x):
        b, c, _, _ = x.size()
        x_flatten = x.view(b, c, -1)
        top_t = int(round(x_flatten.size(2) * self.percent_t))
        
        channel_att_sum = None

        for pool_type in self.pool_types:
            if pool_type == 'avg':
                selected_values, _ = x_flatten.topk(top_t, dim=2)
                pool = selected_values.mean(dim=2, keepdim=True)
            elif pool_type == 'max':
                selected_values, _ = x_flatten.topk(top_t, dim=2)
                pool = selected_values.max(dim=2, keepdim=True)[0]
            else:
                raise ValueError("Invalid pool_type, choose between 'avg' and 'max'")
                
            channel_att_raw = self.mlp(pool)
            
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
                
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
    

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
        
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out) # broadcasting
        return x * scale


class Topt_CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, percent_t=0.7, pool_types=['avg', 'max'], no_spatial=False):
        super(Topt_CBAM, self).__init__()
        self.ChannelGate = TopTPercentChannelGate(gate_channels, reduction_ratio, percent_t=percent_t, pool_types=['avg', 'max'])
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out




import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => PReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid_channels),
            nn.PReLU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.PReLU()
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
    

class top_t_cbam_UNet(nn.Module):
    def __init__(self, n_channels, n_classes,percent_t=0.5, bilinear=False):
        super(top_t_cbam_UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.tCbam0 = Topt_CBAM(64,percent_t=percent_t)

        self.down1 = (Down(64, 128))
        self.tCbam1 = Topt_CBAM(128,percent_t=percent_t)

        self.down2 = (Down(128, 256))
        self.tCbam2 = Topt_CBAM(256,percent_t=percent_t)

        self.down3 = (Down(256, 512))
        self.tCbam3 = Topt_CBAM(512,percent_t=percent_t)

        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, self.tCbam3(x4))
        x = self.up2(x, self.tCbam2(x3))
        x = self.up3(x, self.tCbam1(x2))
        x = self.up4(x, self.tCbam0(x1))
        logits = self.outc(x)
        return logits

if __name__ == "__main__":
    sample = torch.randn((4,1,512,512))
    model = top_t_cbam_UNet(n_channels=1,n_classes=1)

    print(model(sample).shape)



torch.Size([4, 1, 512, 512])


In [163]:
import os
import torch
from dataset import tumor_Dataset

model_path = '/workspace/IITP/task_2D/dir_checkpoint_breast_ROI/top_t_cbam_UNet2.pth'

model = top_t_cbam_UNet(1,1,percent_t=0.3).to(device="cuda:0")
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model,device_ids=[0,1,2,4]) 
model.load_state_dict(torch.load(model_path))


<All keys matched successfully>

In [164]:
model.eval()

DataParallel(
  (module): top_t_cbam_UNet(
    (inc): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): PReLU(num_parameters=1)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (5): PReLU(num_parameters=1)
      )
    )
    (tCbam0): Topt_CBAM(
      (ChannelGate): TopTPercentChannelGate(
        (mlp): Sequential(
          (0): Flatten()
          (1): Linear(in_features=64, out_features=4, bias=True)
          (2): ReLU()
          (3): Linear(in_features=4, out_features=64, bias=True)
        )
      )
      (SpatialGate): SpatialGate(
        (compress): ChannelPool()
        (spatial): BasicConv(
          (conv): Conv2d(2, 1, kernel_size=(7,

In [231]:
image_path = "/mount_folder/input/"

list_path = os.listdir(image_path)

full_image_path = [image_path + f for f in list_path]

full_image_path[0].split("/")[-1]

'1389087_117.dcm'

In [250]:
import pydicom as dcm
from pydicom.pixel_data_handlers.util import apply_voi_lut
import torchvision.transforms as transforms
import random
import numpy as np
import matplotlib.pyplot as plt

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize((512,512))])
for f in full_image_path:
    slice = dcm.read_file(f)
    image_ = slice.pixel_array
    image = apply_voi_lut(image_, slice)
    epsilon = 1e-10
    min_val = np.min(image)
    max_val = np.max(image)
    image = (image - min_val) / (max_val - min_val+epsilon)

    image_ = transform(image_.astype(np.float32))
    image = transform(image)

    image = image.to(dtype=torch.float32).unsqueeze(0)
    output = torch.sigmoid(model(image))
    thresh = torch.zeros_like(output)
    thresh[output>0.5] = 1.0
    roi = (output.detach().cpu().numpy()).squeeze().squeeze() * image_.detach().numpy()

    ds = pydicom.Dataset()
    ds.PatientID = slice.PatientID
    ds.Modality = slice.Modality
    ds.SeriesInstanceUID = slice.SeriesInstanceUID
    ds.SOPInstanceUID = slice.SOPInstanceUID
    ds.SOPClassUID = slice.SOPClassUID
    ds.Rows = roi.shape[0]
    ds.Columns = roi.shape[1]
    ds.PixelData = roi.tobytes()
    ds.BitsAllocated = slice.BitsAllocated
    ds.BitsStored = slice.BitsStored
    ds.HighBit = slice.HighBit
    ds.PixelRepresentation = slice.PixelRepresentation
    ds.SamplesPerPixel = slice.SamplesPerPixel
    ds.PhotometricInterpretation = slice.PhotometricInterpretation
    ds.is_implicit_VR=slice.is_implicit_VR
    ds.is_little_endian=slice.is_little_endian
    new_path = "/mount_folder/roi_input/"+f.split("/")[-1]
    ds.save_as(new_path)






In [240]:
roi = (output.detach().cpu().numpy()).squeeze().squeeze() * image_
print(roi.tobytes())

b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x

In [None]:
sample = pydicom.dcmread("/mount_folder/roi_input/1259487_73.dcm")
print(sample.pixel_array.shape)
