In [1]:
#imports
import os
import random
import tensorflow as tf

In [6]:
#prepare data if working on colab
from google.colab import drive

#mount your Google Drive to access the dataset
drive.mount('/content/drive')

#unzip the data
!unzip /content/drive/MyDrive/project/release/classification.zip -d /content/data
!unzip /content/drive/MyDrive/project/release/verify.zip -d /content/verify  # 生成独立验证集

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Archive:  /content/drive/MyDrive/project/release/classification.zip
replace /content/data/stromal.068.186.TCGA-A1-A0SP-DX1_left-11917_top-54733_bottom-54999_right-12199.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: N
Archive:  /content/drive/MyDrive/project/release/verify.zip
  inflating: /content/verify/stromal.000.216.TCGA-A2-A0YE-DX1_left-62248_top-33996_bottom-34283_right-62545.png  
  inflating: /content/verify/stromal.021.208.TCGA-A2-A0CM-DX1_left-20097_top-57359_bottom-57656_right-20389.png  
  inflating: /content/verify/stromal.022.233.TCGA-AQ-A54N-DX1_left-34950_top-26780_bottom-27096_right-35285.png  
  inflating: /content/verify/stromal.026.034.TCGA-A2-A0YM-DX1_left-46788_top-68959_bottom-69233_right-47127.png  
  inflating: /content/verify/stromal.027.198.TCGA-A2-A04T-DX1_left-74704_top-41367_bottom-41637_right-74980.png  
  inflating: /content/verif

In [7]:
#hyperparameters - do not edit

#超参数设定
#size of input images
size = tf.constant([224, 224], tf.int32)

#normalization for input images
normalizer = tf.keras.applications.resnet50.preprocess_input 

#gradient optimizer
optimizer = 'adam' 

#classification loss
#loss = tf.keras.losses.BinaryCrossentropy() 
loss = tf.keras.losses.CategoricalCrossentropy() 
# CategoricalCrossentropy()

#tensorflow random seed - operation-level and global
tf_random_seed = 0
tf.random.set_seed(tf_random_seed)
glorot_initializer=tf.keras.initializers.GlorotNormal(seed=tf_random_seed)

#label mapping - these classes will be mapped to 0-stromal, 1-TIL, etc.
mapping = ['stromal', 'TIL', 'tumor']

In [8]:
#hyperparameters and other settings - editable

#path to images - absolute or relative to where the notebook is run
path = './data/'
pathvrf = './verify/' # 验证集的内容

#batch size (# images / batch) - adjust only to deal with GPU memory availabilty
batch = 64

#number of training epochs
epochs = 10
#performance metric - you can add additional metrics but do not remove the 'macro_auc' metric
# macro-auc matric -- 生成AUC的？
metrics = [tf.keras.metrics.AUC(name='macro_auc', multi_label=True, num_labels=len(mapping))] 

#measure validation performance after every 10 epochs
validation_freq = 1

#True -> use mixed precision to speedup training
mixed_precision = False#True

#True -> keep training dataset in memory to speedup training
train_cache = True 

#True -> keep validation dataset in memory to speedup validation
validation_cache = True

#random seed for python (non-tensorflow) random operations
py_random_seed = 0

In [9]:
#input pipeline - editable

#function for parsing filenames
# 解析文件名的函数
def parse_filename(file):
    """
    Parses the label, patient, and lab from filename string.
    
    Parameters
    ----------
    file: string
        Name of the png file.
    
    Returns
    -------
    label: string
        Name of class. One of 'tumor', 'stromal', 'TIL', or 'unknown'.
    patient: string
        Unique patient identifier within the dataset.
    lab: string
        Unique identifier of lab that produced patient sample.
    """
    
    label = file.split('.')[0]
    patient = file.split('.')[3][0:12]
    lab = file.split('.')[3].split('-')[1]
    return label, patient, lab


#define preprocessing function for image loading, resizing, and normalization
#定义图像加载、调整大小和归一化的预处理函数
# 将png转成三个通道RBG格式
def load(filename):
    """
    Loads png file into three-dimensional tensor.
    
    Parameters
    ----------
    file: string
        Name of the png file.
    
    Returns
    -------
    img: tensor (uint8)
        Image in RGB format.
    """
    
    img = tf.io.read_file(filename)
    img = tf.image.decode_png(img, channels=3)
    return img


#function for generating tf.data.Datasets
#生成数据集（读入数据集）
def dataset(path, load, cells, mapping, labs, size, normalizer, batch, cache):
    """
    Generates a tensorflow dataset given a list of png files
    and labs to select from.
    
    Parameters
    ----------
    path: string
        Path to .png files to include in the dataset.
    load: function
        Function to load images from disk.
    cells: list (dicts)
        A list of dicts with each dict describing the filename, label, patient id
        and lab id of the image.
    mapping: list (strings)
        A list of class names. Order maps classes to integer labels.
    labs: list (strings)
        A list of strings for lab identifiers to include in the dataset.
    size: tensor (int32)
        A tensor containing the height and width to resize the images to.
    normalizer: function
        A function for normalizing input image values for the network.
    batch: int
        Number of images in a gradient update batch.
    cache: bool
        A 'True' value will cache the dataset in memory to improve speed. 
        If out-of-memory errors occur, set to false.
    
    Returns
    -------
    ds: tf.data.Dataset
        A dataset containing the images located in path with lab identifers
        in 'labs'. Dataset will be shuffled each epoch, and batched, cached,
        and prefetched according to inputs.
    """
    
    
    #autotune - only impacts prefetch/speed performance
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    
    #generate lists of png files and corresponding labels
    ds_cells = [cell for cell in cells if cell['lab'] in labs]
    ds_files = [path + cell['file'] for cell in cells if cell['lab'] in labs]
    ds_labels = [mapping.index(cell['label']) for cell in cells if cell['lab'] in labs]
    
    #create tensorflow dataset from png/label pairs
    ds = tf.data.Dataset.from_tensor_slices((ds_files, ds_labels))
    
    #shuffle list
    print(len(ds_labels))
    print('\n')
    print(ds_cells)
    print(ds_files)
    ds = ds.shuffle(len(ds_labels))#, reshuffle_each_iteration=True

    #apply image loading and one-hot encoding operations to dataset
    ds = ds.map(lambda x, y: (load(x),  tf.one_hot(y, len(mapping))))
        
    #apply image resizing to input size 'size'
    ds = ds.map(lambda x, y: (tf.image.resize(x, size), y))

    #batch for gradient updates
    ds = ds.batch(batch)
    
    #apply normalizer to image
    ds = ds.map(lambda x, y: (normalizer(x), y))
    
    #prefetch for speed performance
    ds = ds.prefetch(AUTOTUNE)
    
    #cache for speed performance if memory is available
    if cache:
        ds = ds.cache()
        
    return ds


#set python package 'random' seed
random.seed(py_random_seed)

#generate list of dicts describing filename, label, patient, and lab for each cell
#randomly shuffle cells so we can safely use a subset of the data (for example)
files = os.listdir('./data')  #
filesvrf = os.listdir('./verify')   #

files = [file for file in files if os.path.splitext(file)[1] == '.png']   # 识别png文件
fields = [parse_filename(file) for file in files 
          if os.path.splitext(file)[1] == '.png'] #解析文件名
cells = [{'file': file, 'label': field[0], 'patient': field[1], 'lab': field[2]}
          for (file, field) in zip(files, fields)]

#------- added -------

filesvrf = [filevrf for filevrf in filesvrf if os.path.splitext(filevrf)[1] == '.png']   # 识别png文件
fieldsvrf = [parse_filename(filevrf) for filevrf in filesvrf 
          if os.path.splitext(filevrf)[1] == '.png'] #解析文件名
cellsvrf = [{'file': filevrf, 'label': field[0], 'patient': field[1], 'lab': field[2]}
          for (filevrf, field) in zip(filesvrf, fieldsvrf)]
#------- end -------------

#remove 'unknown' cells
#移除 unknown 文件
cells = [cell for cell in cells if cell['label'] != 'unknown']
cellsvrf = [cellvrf for cellvrf in cellsvrf if cellvrf['label'] != 'unknown']

#pick a lab at random for validation (for example)
# 随机挑选一个lab做验证集
labs = list(set([cell['lab'] for cell in cells]))
labs.sort()

labsvrf = list(set([cellvrf['lab'] for cellvrf in cellsvrf]))
labsvrf.sort()
# 验证和训练集是不同的lab
#validation_labs = [labs[random.randint(0, len(labs)-1)]]
#train_labs = list(set(labs).difference(set(validation_labs)))
validation_labs = list(set(labsvrf))
train_labs = list(set(labs))

#create training, validation tf.data.Dataset
# 生成训练、验证集
#train_ds =  dataset(path, load, cells, mapping, train_labs, size, 
#                    normalizer,  batch, train_cache)
#validation_ds = dataset(pathvry, load, cells, mapping, validation_labs, size, 
#                        normalizer, batch, validation_cache)

train_ds =  dataset(path, load, cells, mapping, train_labs, size, 
                    normalizer,  batch, train_cache)
validation_ds = dataset(pathvrf, load, cellsvrf, mapping, validation_labs, size, 
                        normalizer, batch, validation_cache)

1500


[{'file': 'TIL.077.176.TCGA-A7-A0CE-DX1_left-58948_top-19236_bottom-19533_right-59249.png', 'label': 'TIL', 'patient': 'TCGA-A7-A0CE', 'lab': 'A7'}, {'file': 'stromal.072.188.TCGA-GM-A2DF-DX1_left-50865_top-46574_bottom-46855_right-51154.png', 'label': 'stromal', 'patient': 'TCGA-GM-A2DF', 'lab': 'GM'}, {'file': 'stromal.076.100.TCGA-A7-A6VV-DX1_left-53025_top-37712_bottom-38014_right-53297.png', 'label': 'stromal', 'patient': 'TCGA-A7-A6VV', 'lab': 'A7'}, {'file': 'TIL.080.065.TCGA-D8-A1JG-DX1_left-18744_top-70225_bottom-70493_right-19005.png', 'label': 'TIL', 'patient': 'TCGA-D8-A1JG', 'lab': 'D8'}, {'file': 'TIL.080.263.TCGA-E2-A159-DX1_left-45531_top-32524_bottom-32810_right-45805.png', 'label': 'TIL', 'patient': 'TCGA-E2-A159', 'lab': 'E2'}, {'file': 'stromal.070.212.TCGA-GM-A2DB-DX1_left-52619_top-43841_bottom-44146_right-52890.png', 'label': 'stromal', 'patient': 'TCGA-GM-A2DB', 'lab': 'GM'}, {'file': 'stromal.077.144.TCGA-EW-A1P1-DX1_left-55413_top-40755_bottom-41045_rig

In [10]:
#training - seek permission to edit

#set mixed precision
if mixed_precision:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

#define model
with tf.device('/device:GPU:0'):
    
    #define base network
    base = tf.keras.applications.ResNet50(
        include_top=False, weights='imagenet', input_tensor=None,
        input_shape=(size[0], size[1], 3), pooling='avg')

    #define base network input and output
    inputs = tf.keras.Input(shape=(size[0], size[1], 3))
    x = base(inputs)
    
    #add dense network
    x = tf.keras.layers.Dense(len(mapping), kernel_initializer=glorot_initializer)(x)
    outputs = tf.keras.layers.Activation('softmax', dtype='float32')(x)
    model = tf.keras.Model(inputs, outputs)

#compile the model with the adam optimizer and cross-entropy loss
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

#fit and validate every 10 epochs
model.fit(x=train_ds, epochs=epochs, validation_data=validation_ds, 
          validation_freq=validation_freq, verbose=1)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f7e9bd403d0>