In [1]:
import sys

# Add the directory two levels up to the module search path. The project-local
# imports in the next cell (datasets, nets, utils) presumably resolve from
# there -- confirm against the repository layout.
# The membership check keeps the cell idempotent: re-running it does not pile
# up duplicate entries in sys.path.
if '../..' not in sys.path:
    sys.path.append('../..')

In [2]:
import copy
import os
import random
from glob import glob
from os.path import join

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydicom
import SimpleITK as sitk
import torch
from scipy.ndimage.interpolation import zoom
from torchvision.transforms import Compose
from tqdm import tqdm

from datasets.crc.crc_dataset_3d import CRCDataset3D
from datasets.crc.transforms import crc_transforms_3d as crcT3D
from nets.ae.auto_encoder import AutoEncoder
from nets.seg.res_unet3d import ResUNet3D
from utils.parse_util import format_config, parse_yaml

%matplotlib inline

In [3]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# The checkpoint file is used directly as a state dict below; map_location
# places the stored tensors on `device` while loading.
crc_ckpt = torch.load(
    '/root/workspace/DCH_AI/ckpts/crc_ae/v_ae_latent_size.512_amcm.64_leakyrelu/46.pth',
    map_location=device)

# NOTE(review): these constructor arguments presumably have to match the ones
# used when the checkpoint was trained (the path mentions latent_size.512 and
# amcm.64) -- confirm before changing them.
crc_net = AutoEncoder(color_channels=1,
                      image_size=np.array([64, 80, 80]),
                      latent_size=512,
                      amcm=64)
crc_net.load_state_dict(crc_ckpt)

# Switch to evaluation mode and move the network to the selected device.
crc_net = crc_net.eval().to(device)

In [4]:
# Per-sample preprocessing pipeline: resize to [64, 80, 80] (the same size that
# is passed to the AutoEncoder as image_size), normalise with mean=128 and
# std=128, then convert to a tensor.
trans = Compose([crcT3D.Resize([64, 80, 80]),
                 crcT3D.Normalize(mean=128, std=128),
                 crcT3D.ToTensor()])

# 3D CRC dataset built from the label CSV and the volume directory, with the
# pipeline above attached as its transforms.
dataset = CRCDataset3D(transforms=trans,
                       sample_csv='/root/workspace/DCH_AI/records/crc/v_crc_labels.csv',
                       data_root='/root/workspace/DCH_AI/data_crc_3d/')

>>> The number of records is 480


In [5]:
# Run every sample through the auto-encoder and export the result to a
# spreadsheet: one row per sample, the sample_id first, followed by the
# flattened values of z.
f_list = []
with torch.no_grad():  # inference only, no gradients needed
    for idx in range(len(dataset)):
        sample = dataset[idx]
        img = sample['image'].to(device)
        # unsqueeze(0) adds a leading batch dimension. With is_z=True the
        # network returns a pair; only the second element (z, presumably the
        # latent code -- confirm in AutoEncoder) is exported.
        x, z = crc_net(img.unsqueeze(0), is_z=True)
        sample_id = dataset.records.iloc[idx].sample_id
        f_list.append([sample_id] +
                      z.flatten().detach().cpu().numpy().tolist())

# Build the table once, after the loop. Rebuilding it on every iteration was
# quadratic in the number of samples and left f_pd undefined for an empty
# dataset.
f_pd = pd.DataFrame(f_list)
f_pd.to_excel('autoencoder_v_crc_features_leakyrelu.xlsx', index=False, header=False)