In [1]:
# !pip install -U openmim
# !mim install mmcv==2.0.0

In [2]:
# loading in and transforming data
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,ConcatDataset
from torch.autograd import Variable

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

#from skimage import io, transform
from PIL import Image

# visualizing data
import matplotlib.pyplot as plt
import numpy as np
import warnings

# load dataset information
import yaml

# image writing
import imageio
from skimage import img_as_ubyte

# Clear GPU cache
torch.cuda.empty_cache()

In [3]:
class LIDC_IDRI_Dataset(Dataset):
    def __init__(self, nodule_path, clean_path, mode, img_size=[128, 128]):

        # nodule_path: path to dataset nodule image folder
        # clean_path: path to dataset clean image folder
        super().__init__()   
        self.nodule_path = nodule_path
        self.clean_path = clean_path
        self.mode = mode
        self.resize = transforms.Resize(img_size)

        # define function to get list of (image, mask)
        self.file_list = self._get_file_list()

        print(len(self.file_list))

    def __len__(self):
        return len(self.file_list)
    
    def _get_file_list(self):
        file_list = []
        for dicom_path in self.nodule_path:
            # Get mask path of nodule image
            mask_path = dicom_path.replace("Image", "Mask")
            mask_path = mask_path.replace("NI", "MA")

            # Check whether mask path exist
            if os.path.exists(mask_path):

                image = np.load(dicom_path)
                # print(image)

                # image = self._normalize_image(image)
                mask = np.load(mask_path)

                # convert image, mask to tensor

                image = torch.from_numpy(image).to(torch.float)
                mask = torch.from_numpy(mask).to(torch.float)

                # add batch dimension 

                image = image.unsqueeze(0)
                mask = mask.unsqueeze(0)
                file_list.append((image, mask))
            
        for dicom_path in self.clean_path:
            # Get mask path of nodule image
            mask_path = dicom_path.replace("Image", "Mask")
            mask_path = mask_path.replace("CN", "CM")

            # Check whether mask path exist

            if os.path.exists(mask_path):

                image = np.load(dicom_path)
                # print(np.max(image))

                # image = self._normalize_image(image)
                mask = np.load(mask_path)

                # convert image, mask to tensor

                image = torch.from_numpy(image).to(torch.float)
                mask = torch.from_numpy(mask).to(torch.float)

                # add batch dimension 

                image = image.unsqueeze(0)
                mask = mask.unsqueeze(0)

                file_list.append((image, mask))

            return file_list

    def __getitem__(self, index):
        image, mask = self.file_list[index]
        return self.resize(image), self.resize(mask)

    def _normalize_image(self, image):
        min_val = np.min(image)
        max_val = np.max(image)

        if max_val - min_val > 0:
            image = (image - min_val) / (max_val - min_val)

        return image

    

In [4]:
def split_data(file_paths, train_val_test_split):
    # get len files
    num_files = len(file_paths)
    # print(num_files)
    
    # ratio
    train_ratio, val_ratio, test_ratio = train_val_test_split
    
    # get num train, val, test
    num_train = int(num_files * train_ratio / (train_ratio + val_ratio + test_ratio))
    num_val = int(num_files * val_ratio / (train_ratio + val_ratio + test_ratio))
    
    # get random index
    train_paths = list(np.random.choice(file_paths, num_train, replace=False))
    val_paths = list(np.random.choice(list(set(file_paths) - set(train_paths)), num_val, replace=False))
    test_paths = list(set(file_paths) - set(train_paths) - set(val_paths))
    return train_paths, val_paths, test_paths
        

In [5]:
nodule_dir = "/data/thanhdd/Lang/Lung-Segmentation/data/LIDC-IDRI-Preprocessing/data/Image"
clean_dir = "/data/thanhdd/Lang/Lung-Segmentation/data/LIDC-IDRI-Preprocessing/data/Clean/Image"
num_nodule = 5000
num_clean = 5000
train_val_test_split = (3,1,1)
img_size=[352,352]
model_type = 'B4'
batch_size = 8
_model_name = 'ESFP_{}_Endo_{}'.format(model_type,"LIDC_IDRI_Dataset")
repeats = 1
n_epochs = 200
file_nodule_list = []
file_clean_list = []

In [6]:
for root, _, files in os.walk(nodule_dir):
    for file in files:
        if file.endswith(".npy"):
            dicom_path = os.path.join(root, file)
            file_nodule_list.append(dicom_path)

# get full path of each clean file
for root, _, files in os.walk(clean_dir):
    for file in files:
        if file.endswith(".npy"):
            dicom_path = os.path.join(root, file)
            file_clean_list.append(dicom_path)

In [7]:
file_nodule_list = file_nodule_list[:num_nodule]

file_clean_list = file_clean_list[:num_clean]

nodule_train, nodule_val, nodule_test = split_data(file_nodule_list, train_val_test_split)

clean_train, clean_val, clean_test = split_data(file_clean_list, train_val_test_split)

In [8]:
data_train = LIDC_IDRI_Dataset(nodule_train, clean_train, mode="train", img_size=img_size)

data_val = LIDC_IDRI_Dataset(nodule_val, clean_val, mode="valid", img_size=img_size)

data_test = LIDC_IDRI_Dataset(nodule_test, clean_test, mode="test", img_size=img_size)

3001
1001
1001


In [9]:
train_dataloader = DataLoader(dataset=data_train,batch_size=batch_size,shuffle=True)
val_dataloader = DataLoader(dataset=data_val,batch_size=8,shuffle=False)
test_dataloader = DataLoader(dataset=data_test,batch_size=1,shuffle=False)


In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [11]:
import sys
sys.path.append('/data/thanhdd/Lang/Lung-Segmentation/src/models/components')
from Encoder import mit
from Decoder import mlp
from mmcv.cnn import ConvModule

class ESFPNetStructure(nn.Module):

    def __init__(self, embedding_dim = 160):
        super(ESFPNetStructure, self).__init__()

        # Backbone
        if model_type == 'B0':
            self.backbone = mit.mit_b0()
        if model_type == 'B1':
            self.backbone = mit.mit_b1()
        if model_type == 'B2':
            self.backbone = mit.mit_b2()
        if model_type == 'B3':
            self.backbone = mit.mit_b3()
        if model_type == 'B4':
            self.backbone = mit.mit_b4()
        if model_type == 'B5':
            self.backbone = mit.mit_b5()

        # self._init_weights()  # load pretrain

        # LP Header
        self.LP_1 = mlp.LP(input_dim = self.backbone.embed_dims[0], embed_dim = self.backbone.embed_dims[0])
        self.LP_2 = mlp.LP(input_dim = self.backbone.embed_dims[1], embed_dim = self.backbone.embed_dims[1])
        self.LP_3 = mlp.LP(input_dim = self.backbone.embed_dims[2], embed_dim = self.backbone.embed_dims[2])
        self.LP_4 = mlp.LP(input_dim = self.backbone.embed_dims[3], embed_dim = self.backbone.embed_dims[3])

        # Linear Fuse
        self.linear_fuse34 = ConvModule(in_channels=(self.backbone.embed_dims[2] + self.backbone.embed_dims[3]), out_channels=self.backbone.embed_dims[2], kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True))
        self.linear_fuse23 = ConvModule(in_channels=(self.backbone.embed_dims[1] + self.backbone.embed_dims[2]), out_channels=self.backbone.embed_dims[1], kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True))
        self.linear_fuse12 = ConvModule(in_channels=(self.backbone.embed_dims[0] + self.backbone.embed_dims[1]), out_channels=self.backbone.embed_dims[0], kernel_size=1,norm_cfg=dict(type='BN', requires_grad=True))

        # Fused LP Header
        self.LP_12 = mlp.LP(input_dim = self.backbone.embed_dims[0], embed_dim = self.backbone.embed_dims[0])
        self.LP_23 = mlp.LP(input_dim = self.backbone.embed_dims[1], embed_dim = self.backbone.embed_dims[1])
        self.LP_34 = mlp.LP(input_dim = self.backbone.embed_dims[2], embed_dim = self.backbone.embed_dims[2])

        # Final Linear Prediction
        self.linear_pred = nn.Conv2d((self.backbone.embed_dims[0] + self.backbone.embed_dims[1] + self.backbone.embed_dims[2] + self.backbone.embed_dims[3]), 1, kernel_size=1)

    def _init_weights(self):

        if model_type == 'B0':
            pretrained_dict = torch.load('/data/thanhdd/Lang/Lung-Segmentation/src/models/components/Pretrained/mit_b0.pth')
        if model_type == 'B1':
            pretrained_dict = torch.load('/data/thanhdd/Lang/Lung-Segmentation/src/models/components/Pretrained/mit_b1.pth')
        if model_type == 'B2':
            pretrained_dict = torch.load('/data/thanhdd/Lang/Lung-Segmentation/src/models/components/Pretrained/mit_b2.pth')
        if model_type == 'B3':
            pretrained_dict = torch.load('/data/thanhdd/Lang/Lung-Segmentation/src/models/components/Pretrained/mit_b3.pth')
        if model_type == 'B4':
            pretrained_dict = torch.load('/data/thanhdd/Lang/Lung-Segmentation/src/models/components/Pretrained/mit_b4.pth')
        if model_type == 'B5':
            pretrained_dict = torch.load('/data/thanhdd/Lang/Lung-Segmentation/src/models/components/Pretrained/mit_b5.pth')


        model_dict = self.backbone.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.backbone.load_state_dict(model_dict)
        print("successfully loaded!!!!")


    def forward(self, x):

        ##################  Go through backbone ###################

        B = x.shape[0]

        #stage 1
        out_1, H, W = self.backbone.patch_embed1(x)
        for i, blk in enumerate(self.backbone.block1):
            out_1 = blk(out_1, H, W)
        out_1 = self.backbone.norm1(out_1)
        out_1 = out_1.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()  #(Batch_Size, self.backbone.embed_dims[0], 88, 88)

        # stage 2
        out_2, H, W = self.backbone.patch_embed2(out_1)
        for i, blk in enumerate(self.backbone.block2):
            out_2 = blk(out_2, H, W)
        out_2 = self.backbone.norm2(out_2)
        out_2 = out_2.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()  #(Batch_Size, self.backbone.embed_dims[1], 44, 44)

        # stage 3
        out_3, H, W = self.backbone.patch_embed3(out_2)
        for i, blk in enumerate(self.backbone.block3):
            out_3 = blk(out_3, H, W)
        out_3 = self.backbone.norm3(out_3)
        out_3 = out_3.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()  #(Batch_Size, self.backbone.embed_dims[2], 22, 22)

        # stage 4
        out_4, H, W = self.backbone.patch_embed4(out_3)
        for i, blk in enumerate(self.backbone.block4):
            out_4 = blk(out_4, H, W)
        out_4 = self.backbone.norm4(out_4)
        out_4 = out_4.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()  #(Batch_Size, self.backbone.embed_dims[3], 11, 11)

        # go through LP Header
        lp_1 = self.LP_1(out_1)
        lp_2 = self.LP_2(out_2)
        lp_3 = self.LP_3(out_3)
        lp_4 = self.LP_4(out_4)

        # linear fuse and go pass LP Header
        lp_34 = self.LP_34(self.linear_fuse34(torch.cat([lp_3, F.interpolate(lp_4,scale_factor=2,mode='bilinear', align_corners=False)], dim=1)))
        lp_23 = self.LP_23(self.linear_fuse23(torch.cat([lp_2, F.interpolate(lp_34,scale_factor=2,mode='bilinear', align_corners=False)], dim=1)))
        lp_12 = self.LP_12(self.linear_fuse12(torch.cat([lp_1, F.interpolate(lp_23,scale_factor=2,mode='bilinear', align_corners=False)], dim=1)))

        # get the final output
        lp4_resized = F.interpolate(lp_4,scale_factor=8,mode='bilinear', align_corners=False)
        lp3_resized = F.interpolate(lp_34,scale_factor=4,mode='bilinear', align_corners=False)
        lp2_resized = F.interpolate(lp_23,scale_factor=2,mode='bilinear', align_corners=False)
        lp1_resized = lp_12

        out = self.linear_pred(torch.cat([lp1_resized, lp2_resized, lp3_resized, lp4_resized], dim=1))

        return out

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# def ange_structure_loss(pred, mask, smooth=1):

#     weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=15, stride=1, padding=7) - mask)
#     wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='mean')
#     wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

#     pred = torch.sigmoid(pred)
#     inter = ((pred * mask)*weit).sum(dim=(2, 3))
#     union = ((pred + mask)*weit).sum(dim=(2, 3))
#     wiou = 1 - (inter + smooth)/(union - inter + smooth)

#     return (wbce + wiou).mean()

def ange_structure_loss(pred, mask, smooth=1):

    weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=15, stride=1, padding=7) - mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='mean')
    wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask)*weit).sum(dim=(2, 3))
    union = ((pred + mask)*weit).sum(dim=(2, 3))
    wiou = 1 - (inter + smooth)/(union - inter + smooth)

    threshold = torch.tensor([0.5]).to(device)
    pred = (pred > threshold).float() * 1

    pred = pred.data.cpu().numpy().squeeze()
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    # print(type(pred))
    return (wbce + wiou).mean(), torch.from_numpy(pred)

def dice_loss_coff(pred, target, smooth = 0.0001):

    num = target.size(0)
    pred = pred.contiguous()
    target = target.contiguous()

    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)

    return loss.sum()/num

In [13]:
from torch.autograd import Variable

def evaluate():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ESFPNet.eval()

    val = 0
    count = 0

    smooth = 1e-4

    iter_val = iter(val_dataloader)
    for i in range(len(val_dataloader)):
        # image, gt, name = val_dataloader.load_data()
        image, gt = next(iter_val)
        # gt = gt.squeeze(0)
        # gt = gt.squeeze(0)
        # print(image)
        gt = np.asarray(gt, np.float32)
        gt /= (gt.max() + 1e-8)

        image = image.cuda()

        pred= ESFPNet(image)
        # print('Evalutate: ')
        # print("image: ", image.shape)
        # print('before')
        # print(pred.shape, gt.shape)
        pred = F.upsample(pred, size=img_size, mode='bilinear', align_corners=False)
        # print('after')
        # print(pred.shape, gt.shape)
        pred = pred.sigmoid()
        threshold = torch.tensor([0.5]).to(device)
        pred = (pred > threshold).float() * 1

        pred = pred.data.cpu().numpy().squeeze()
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

        target = np.array(gt)

        input_flat = np.reshape(pred,(-1))
        target_flat = np.reshape(target,(-1))

        intersection = (input_flat*target_flat)

        loss =  (2 * intersection.sum() + smooth) / (pred.sum() + target.sum() + smooth)

        a =  '{:.4f}'.format(loss)
        a = float(a)

        val = val + a
        count = count + 1

    ESFPNet.train()

    return val/count

In [14]:
# train the network
def training_loop(n_epochs, ESFPNet_optimizer, numIters):

    # keep track of losses over time
    losses = []
    coeff_max = 0

    # set up data and then train
    iter_X = iter(train_dataloader)
    steps_per_epoch = len(iter_X)
    num_epoch = 0
    total_steps = (n_epochs+1)*steps_per_epoch

    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for step in range(1, total_steps):

        # Reset iterators for each epoch
        if step % steps_per_epoch == 0:
            iter_X = iter(train_dataloader)
            num_epoch = num_epoch + 1

        # make sure to scale to a range -1 to 1
        images, masks = next(iter_X)
        # print("Train:")
        # print(images.shape, masks.shape)
        # move images to GPU if available (otherwise stay on CPU)
        images = images.to(device)
        masks = masks.to(device)


        # ============================================
        #            TRAIN THE NETWORKS
        # ============================================

        ESFPNet_optimizer.zero_grad()

        # 1. Compute the losses from the network

        out = ESFPNet(images)
        out = F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=False)

        loss, pred = ange_structure_loss(out, masks)

        loss.backward()
        ESFPNet_optimizer.step()

        # ============================================
        #            TRAIN THE NETWORKS
        # ============================================
        # Print the log info
        if step % steps_per_epoch == 0:

            losses.append(loss.item())
            print('Epoch [{:5d}/{:5d}] | preliminary loss: {:6.6f} '.format(num_epoch, n_epochs, loss.item()))

        if step % steps_per_epoch == 0:

            validation_coeff = evaluate()
            print('Epoch [{:5d}/{:5d}] | validation_coeffient: {:6.6f} '.format(
                    num_epoch, n_epochs, validation_coeff))

            if coeff_max < validation_coeff:
                coeff_max = validation_coeff
                save_model_path = './SaveModel/{}_LA_{:1d}'.format(_model_name,numIters)
                os.makedirs(save_model_path, exist_ok=True)
                print(save_model_path)
                torch.save(ESFPNet, save_model_path + '/ESFPNet.pt')
                print('Save Learning Ability Optimized Model at Epoch [{:5d}/{:5d}]'.format(num_epoch, n_epochs))

    return losses, coeff_max

In [15]:
def saveResult(numIters):

    save_path = './results/{}_LA_{:1d}/{}_Splited/'.format(_model_name,numIters,str("LIDC_IDRI_Dataset"))
    os.makedirs(save_path, exist_ok=True)
    print(save_path)

    model_path =  './SaveModel/{}_LA_{:1d}'.format(_model_name,numIters)
    ESFPNetBest = torch.load(model_path + '/ESFPNet.pt')
    ESFPNetBest.eval()
    iter_test = iter(test_dataloader)
    for i in range(len(test_dataloader)):
        image, gt = next(iter_test)
        gt = np.asarray(gt, np.float32)
        gt /= (gt.max() + 1e-8)
        image = image.cuda()

        pred = ESFPNetBest(image)
        pred = F.upsample(pred, size=img_size, mode='bilinear', align_corners=False)
        pred = pred.sigmoid()
        threshold = torch.tensor([0.5]).to(device)
        pred = (pred > threshold).float() * 1
        pred = pred.data.cpu().numpy().squeeze()
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

        imageio.imwrite(save_path+str(i),img_as_ubyte(pred), 'png')

In [16]:
# import torch.optim as optim

# for i in range(repeats):
#     # Clear GPU cache
#     torch.cuda.empty_cache()

#     ESFPNet = ESFPNetStructure()
#     if torch.cuda.is_available():
#         device = torch.device("cuda:0")
#         ESFPNet.to(device)
#         print('Models moved to GPU.')
#     else:
#         print('Only CPU available.')
#     print('#####################################################################################')

#     # hyperparams for Adam optimizer
#     lr=0.0001 #0.0001

#     ESFPNet_optimizer = optim.AdamW(ESFPNet.parameters(), lr=lr)

#     losses, coeff_max = training_loop(n_epochs, ESFPNet_optimizer, i+1)

#     plt.plot(losses)

#     print('#####################################################################################')
#     print('optimize_m_dice: {:6.6f}'.format(coeff_max))

#     saveResult(i+1)
#     print('#####################################################################################')
#     print('saved the results')
#     print('#####################################################################################')

In [17]:
from matplotlib import pyplot as plt

def saveResult():

    save_path = './results/{}/{}/'.format(_model_name,'LIDC_IDRI_Dataset')
    os.makedirs(save_path, exist_ok=True)

    model_path = '/data/thanhdd/Lang/Lung-Segmentation/SaveModel/ESFP_B4_Endo_LIDC_IDRI_Dataset_LA_1'
    ESFPNetBest = torch.load(model_path + '/ESFPNet.pt')
    ESFPNetBest.eval()
    iter_test = iter(test_dataloader)
    print(len(test_dataloader))
    for i in range(len(test_dataloader)):
        image, gt = next(iter_test)
        gt = np.asarray(gt, np.float32)
        gt /= (gt.max() + 1e-8)
        image = image.cuda()

        pred = ESFPNetBest(image)
        # print(pred)
        pred = F.upsample(pred, size=img_size, mode='bilinear', align_corners=False)
        pred = pred.sigmoid()
        threshold = torch.tensor([0.5]).to(device)
        pred = (pred > threshold).float() * 1
        pred = pred.data.cpu().numpy().squeeze()
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
        
        print(pred.shape)
        # imageio.imwrite(save_path+str(i),img_as_ubyte(pred))
        # plt.imshow(pred, cmap='gray')
        # plt.show()
        # fig, axs = plt.subplots(1, 3, figsize=(6, 12))
        # axs[0].imshow(image.squeeze(0).permute(1, 2, 0).data.cpu().numpy(),cmap='gray')
        # axs[0].set_title('Image')
        # axs[1].imshow(pred, cmap='gray')
        # axs[1].set_title('pred')
        # # print(gt.shape)
        # axs[2].imshow(np.transpose(gt.squeeze(0),(1,2,0)), cmap='gray')
        # axs[2].set_title('gt')

        # # Adjust layout and spacing
        # plt.tight_layout()

        # # Show the plot
        # plt.show()
        

In [18]:
saveResult()

1001




tensor([[[[ -16.5182,  -15.8488,  -17.3689,  ...,  -15.3340,  -20.5975,
            -25.5738],
          [ -18.1218,  -17.6957,  -22.0633,  ...,  -17.5376,  -31.2101,
            -32.9771],
          [ -18.1412,  -17.5066,  -20.7016,  ...,  -17.8182,  -41.7233,
            -46.0816],
          ...,
          [ -40.9857,  -39.0650,  -31.8214,  ...,  -62.1101, -110.7695,
           -119.5804],
          [ -80.2707,  -70.1386,  -41.6531,  ...,  -60.6190,  -55.7441,
            -44.3265],
          [ -98.9146,  -88.1454,  -44.1178,  ...,  -65.4343,  -34.5705,
            -36.5783]]]], device='cuda:0', grad_fn=<ConvolutionBackward0>)
(352, 352)
tensor([[[[ -17.4177,  -18.2326,  -19.5731,  ...,  -19.3434,  -21.9914,
            -36.4075],
          [ -21.9592,  -17.1106,  -20.2609,  ...,  -17.8308,  -19.6024,
            -28.7001],
          [ -25.1252,  -23.6622,  -19.3466,  ...,  -19.1461,  -22.3391,
            -39.5084],
          ...,
          [ -46.5675,  -45.9184,  -33.7115,  ...,  -



tensor([[[[-16.9064, -18.9542, -20.6192,  ..., -20.4315, -25.5804, -35.1078],
          [-19.0591, -20.2730, -23.9370,  ..., -20.8656, -30.3426, -33.5498],
          [-22.3239, -21.0279, -19.9526,  ..., -20.3438, -28.9310, -42.4011],
          ...,
          [-40.2388, -42.1400, -32.3421,  ..., -65.5168, -78.2284, -89.0202],
          [-79.5551, -65.2055, -51.5240,  ..., -84.5656, -55.9790, -52.6965],
          [-89.3181, -53.9584, -41.1881,  ..., -77.9211, -33.9646, -46.2447]]]],
       device='cuda:0', grad_fn=<ConvolutionBackward0>)
(352, 352)
tensor([[[[ -16.6556,  -19.1554,  -20.0483,  ...,  -23.2964,  -25.7890,
            -28.5868],
          [ -18.5613,  -31.4517,  -31.9636,  ...,  -65.2942,  -38.6165,
            -23.1920],
          [ -20.6923,  -24.7203,  -22.5574,  ...,  -57.2630,  -29.3326,
            -24.4663],
          ...,
          [ -47.3279,  -64.6076,  -45.9216,  ..., -152.6984,  -86.5889,
            -42.5746],
          [ -96.1663,  -97.3675,  -77.1660,  ..., -1