In [340]:
%%writefile logger_config.py

import logging

logger = logging.getLogger()

def setup_logger(run_path):
    log_format = "%(asctime)s - %(levelname)s - %(message)s"

    logger.setLevel(logging.DEBUG)

    file_handler = logging.FileHandler(run_path / "output.log")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(log_format))
    logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    logger.addHandler(console_handler)

Overwriting logger_config.py


In [341]:
%%writefile experiment_setup_snp.py

import os
from pathlib import Path
import datetime
import logging
import sys

MAIN_DIR = "/gpfs/gibbs/pi/gerstein/tu54/imaging_project/expression-prediction/thyroid"
EXP_NAME = "Thyroid-by-tile-NIC-CNN"

def get_run_folder(args):
    args_str = f"lr{args.learning_rate}-test_size{args.test_size}-batch_size{args.batch_size}-epochs{args.epochs}-column{args.snp_column}"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    return f"run_{now_str}_{args_str}"

def create_path(parent_dir, child_dirs):
    path = parent_dir
    for child_dir in child_dirs:
        path = path / child_dir
        path.mkdir(exist_ok=True)
    return path

def initialize_experiment(args):
    current_path = create_path(Path(MAIN_DIR), ["data", EXP_NAME, get_run_folder(args)])
    return current_path

Overwriting experiment_setup_snp.py


In [344]:
%%writefile data_setup_snp.py

import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical
from logger_config import logger

def create_datasets(
    csv_file_path: str,
    image_dir_path: str,
    current_path: str,
    barcode_column: str,
    seed: int,
    test_size: float,
    snp_column: int
):
    data = pd.read_csv(csv_file_path)
    data['genotype'] = data[barcode_column].apply(lambda x: 'AA' if x[snp_column] == '0' else 'AC' if x[snp_column] == '1' else 'CC')

    image_file = [i.split("_")[0] for i in os.listdir(image_dir_path)]
    image_filenames_df = pd.DataFrame(image_file, columns=['image_file'])
    image_filenames_df['Tissue Sample ID'] = image_filenames_df['image_file']

    merged_data = pd.merge(image_filenames_df, data, left_on='Tissue Sample ID', right_on='Tissue Sample ID')
    print(len(merged_data))

    shuffled_data = merged_data.sample(frac=1, random_state=seed)
    shuffled_data = shuffled_data.drop(columns=['Tissue Sample ID', 'Unnamed: 0', 'Tissue', 'Patient ID'])

    logger.info("Final dataset (first few rows): \n%s", shuffled_data.head())

    shuffled_data.to_csv(os.path.join(current_path, "data_with_genotype.csv"))

    features = shuffled_data[['image_file', 'barcode']]
    labels = shuffled_data['genotype']

    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=test_size, random_state=seed)

    label_encoder = LabelEncoder()
    y_train_encoded = label_encoder.fit_transform(y_train)
    y_test_encoded = label_encoder.transform(y_test)

    y_train_onehot = to_categorical(y_train_encoded)
    y_test_onehot = to_categorical(y_test_encoded)

    class_names = ['AA', 'AC', 'CC']
    logger.info("Class names: \n%s", class_names)

    return (X_train, y_train_onehot), (X_test, y_test_onehot), class_names

Overwriting data_setup_snp.py


In [347]:
%%writefile augmentation_snp.py

import os
import numpy as np
import random
from utils import selection
from logger_config import logger

IMG_DIR = "/gpfs/gibbs/pi/gerstein/jrt62/imaging_project/expression-prediction/Thyroid/Thyroid-no-bounding-features-remove-bg"

def is_valid_patch(arr):
    # Check if a patch is valid
    return arr.shape[0] == 64 and arr.shape[1] == 64 and arr.shape[2] == 128 and np.isnan(arr[:,:,1]).sum() < 1000

def augment_patch(arr, y, sum_x, sum_y, is_train_data):
    # If the patch is valid, augment it and append it to sum_x and sum_y
    if is_valid_patch(arr):
        arr1 = np.nan_to_num(arr)
        sum_x.append(arr1)
        sum_y.append(y)
        if is_train_data: 
            num = random.randint(1,7)
            sum_x.append(selection(arr1, num))
            sum_y.append(y)

def process_slide(i, y_label, sum_x, sum_y, is_train_data):
    # Process a whole slide
    a = np.load(i + "_features.npy") # Load the compressed slide
    b = np.swapaxes(a, 0, 2) # Swap axes
    unit_list = []
    
    for j in range(128):
        c = b[:,:,j][~np.isnan(b[:,:,j]).all(axis=1)] # Remove NaN rows
        d = (c.T[~np.isnan(c.T).all(axis=1)]).T # Remove NaN columns
        unit_list.append(d)
        
    e = np.array(unit_list)
    f = np.swapaxes(e, 0, 2)
    
    f_len = f.shape[0]
    f_width = f.shape[1]
    f_len_int = f_len // 32
    f_width_int = f_width // 32
    
    for k in range(f_len_int - 2):
        for p in range(f_width_int - 2):
            patch100 = f[k * 32:(k * 32 + 64), p * 32:(p * 32 + 64),:]
            augment_patch(patch100, y_label, sum_x, sum_y, is_train_data)
            
        patch100 = f[k * 32:(k * 32 + 64), -65:-1,:]
        augment_patch(patch100, y_label, sum_x, sum_y, is_train_data)

    for p in range(f_width_int):
        patch100 = f[-65:-1,p * 32:(p * 32 + 64),:]
        augment_patch(patch100, y_label, sum_x, sum_y, is_train_data)
    patch100 = f[-65:-1,-65:-1,:]
    augment_patch(patch100, y_label, sum_x, sum_y, is_train_data)

def process_all_slides(data, labels, is_train_data):
    sum_x = []
    sum_y = []

    os.chdir(IMG_DIR)
    if is_train_data:
        logger.info("Processing all slides: training")
    else:
        logger.info("Processing all slides: testing")

    for index, _ in enumerate(data.iterrows()):
        # if is_train_data:
        #     print(index)
        # else:
        #     print(_)
        y_label = labels[index] # access label using integer index
        image_file = data.iloc[index]['image_file']
        process_slide(image_file, y_label, sum_x, sum_y, is_train_data)

    logger.info("Finished processing.")
    return sum_x, sum_y

def stack_data(sum_x, sum_y):
    x = np.stack(sum_x)
    y = np.array(sum_y)
    return x, y

Overwriting augmentation_snp.py


In [364]:
%%writefile model_builder_snp.py

from tensorflow.keras import layers, regularizers
import tensorflow as tf

def build_model(output_units: int, input_shape: tuple):
    dropout = 0.3

    # Input layer 
    inputs = layers.Input(shape=input_shape)
    
    # Conv Block 1
    conv1 = layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')(inputs)
    pool1 = layers.MaxPooling2D()(conv1)
    drop1 = layers.Dropout(dropout)(pool1)
    
    # Conv Block 2 
    conv3 = layers.Conv2D(filters=128, kernel_size=3, padding='same', activation='relu')(drop1)
    pool2 = layers.MaxPooling2D()(conv3)
    drop2 = layers.Dropout(dropout)(pool2)

    # Conv Block 3 
    conv5 = layers.Conv2D(filters=256, kernel_size=3, padding='same', activation='relu')(drop2)
    pool3 = layers.MaxPooling2D()(conv5)
    drop3 = layers.Dropout(dropout)(pool3)

    # Dense layers
    flat = layers.Flatten()(drop3) 
    dense1 = layers.Dense(4096, activation='relu')(flat)
    dense2 = layers.Dense(1024, activation='relu')(dense1)
    dense3 = layers.Dense(128, activation='relu')(dense2)
    dense4 = layers.Dense(16, activation='relu')(dense3)
    outputs = layers.Dense(output_units, activation='softmax')(dense4)
    model = tf.keras.Model(inputs, outputs)

    return model

Overwriting model_builder_snp.py


In [379]:
%%writefile save_models_snp.py

import os
from pathlib import Path
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
import pickle
import numpy as np
from logger_config import logger

def save_model(model: tf.keras.Model,
               history: tf.keras.callbacks.History,
               current_path: str,
               target_dir: str,
               model_name: str,
               y_pred: np.ndarray,
               test_label: np.ndarray):
    # Ensure target directory exists
    target_dir_path = Path(current_path) / target_dir
    target_dir_path.mkdir(parents=True, exist_ok=True)
    
    # Create model save path
    assert model_name.endswith(".h5"), "model_name should end with '.h5'"
    model_save_path = target_dir_path / model_name
    
    # Save the model
    # print(f"[INFO] Saving model to: {model_save_path}")
    logger.info("Saving model to: %s", model_save_path)
    model.save(model_save_path)
    
    # Save the training history
    with open(target_dir_path / 'basic_history.pickle', 'wb') as f:
        pickle.dump(history.history, f)
    
    # Save the prediction and true labels
    np.save(target_dir_path / "y_pred.npy", y_pred)
    np.save(target_dir_path / "test_label.npy", test_label)

Overwriting save_models_snp.py


In [381]:
%%writefile plots_snp.py

import numpy as np
from sklearn.metrics import precision_recall_curve, f1_score
from tensorflow.keras.utils import to_categorical
from matplotlib import pyplot as plt

num_s="top250-8-tile-lr0.00001-optimized-r20-020123-multi-less-augumentation-more-shift"

def plot_learning_rate(a=0.00001, steps=20000, rate_decay=0.4):
    y = []
    x = range(steps)
    for i in range(steps):
        y.append(a * pow(rate_decay, i / 10000))

    plt.figure(figsize=(8.5, 8))
    plt.style.use("classic")
    plt.plot(x,y)
    plt.show()

def plot_accuracy(history, current_path):
    plt.figure(figsize=(8.5, 8))
    plt.style.use("classic")
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.savefig(current_path / "accuracy.png", bbox_inches="tight")
    plt.show()

def plot_categorical_crossentropy_loss(history, current_path):
    plt.figure(figsize=(8.5, 8))
    plt.style.use("classic")
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('categorical crossentropy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.savefig(current_path / "categorical_crossentropy_loss.png", bbox_inches="tight")
    plt.show()

def plot_f1_scores(y_true, y_pred, current_path, labels=None):
    if len(y_true.shape) == 1:  # if y_true is integer encoded
        y_true = to_categorical(y_true)
    if len(y_pred.shape) == 1:  # if y_pred is integer encoded
        y_pred = to_categorical(y_pred)

    f1_scores = f1_score(y_true.argmax(axis=1), y_pred.argmax(axis=1), average=None)
    plt.bar(range(len(f1_scores)), f1_scores)
    plt.xlabel('Class')
    plt.ylabel('F1 Score')
    plt.title('F1 Scores for Each Class')
    if labels:
        plt.xticks(range(len(f1_scores)), labels)
    plt.savefig(current_path / "f1_scores.png", bbox_inches="tight")
    plt.show()

def plot_precision_recall(y_true, y_pred_proba, current_path, labels=None):
    for i in range(y_pred_proba.shape[1]): # loop over each class
        precision, recall, _ = precision_recall_curve(y_true[:, i], y_pred_proba[:, i])
        label = f'Class {i}' if labels is None else labels[i]
        plt.plot(recall, precision, lw=2, label=label)

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='best')
    plt.grid()
    plt.savefig(current_path / "precision_recall_curve.png", bbox_inches="tight")
    plt.show()

def all_plots(history, y_pred, test_label, current_path, num_s, y_pred_proba=None):
    plot_learning_rate()
    plt.show()
    
    plot_accuracy(history, current_path)
    
    plot_categorical_crossentropy_loss(history, current_path)
    
    if y_pred_proba is not None:
        plot_f1_scores(test_label, y_pred, current_path)
        plot_precision_recall(test_label, y_pred_proba, current_path)

Overwriting plots_snp.py


In [382]:
%%writefile train_snp.py

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras import optimizers, losses
import data_setup_snp, augmentation_snp, experiment_setup_snp, model_builder_snp, save_models_snp, plots_snp
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from matplotlib import pyplot as plt
import argparse
from tensorflow.keras.optimizers import Adam
import numpy as np
from sklearn.metrics import roc_curve, average_precision_score, f1_score, precision_recall_curve, auc
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
from sklearn.metrics import classification_report
from logger_config import setup_logger, logger

parser = argparse.ArgumentParser(description='Training script')
parser.add_argument('--seed', type=int, default=20)
parser.add_argument('--test_size', type=float, default=0.33)
parser.add_argument('--learning_rate', type=float, default=0.0004)
parser.add_argument('--batch_size', type=int, default=10)
parser.add_argument('--epochs', type=int, default=5)
parser.add_argument('--snp_column', type=int, default=0)
parser.add_argument('--ylabel', type=str, default='/gpfs/gibbs/pi/gerstein/tu54/imaging_project/barcodes_thyroid26.csv')

args = parser.parse_args()

num_s = "top250-8-tile-lr0.00001-optimized-r20-020123-multi-less-augumentation-more-shift"

# Initialize experiment
current_path = experiment_setup_snp.initialize_experiment(args)
setup_logger(current_path)
logger.info(tf.__version__)
logger.info(tf.config.list_physical_devices('GPU'))
logger.info('Experiment setup complete.')

# Prepare training and testing data
(train_data, train_labels), (test_data, test_labels), class_names = data_setup_snp.create_datasets(args.ylabel, 
                                                     '/gpfs/gibbs/pi/gerstein/jrt62/imaging_project/expression-prediction/Thyroid/Thyroid-no-bounding-features-remove-bg',
                                                     current_path,
                                                     'barcode',
                                                     args.seed,
                                                     args.test_size,
                                                     args.snp_column)

sum_x_train, sum_y_train = augmentation_snp.process_all_slides(train_data, train_labels, True)
sum_x_test, sum_y_test = augmentation_snp.process_all_slides(test_data, test_labels, False)

train_image, train_label = augmentation_snp.stack_data(sum_x_train, sum_y_train)
test_image, test_label = augmentation_snp.stack_data(sum_x_test, sum_y_test)
logger.info("Train image shape: %s", train_image.shape)
logger.info("Train label shape: %s", train_label.shape)
logger.info("Test image shape: %s", test_image.shape)
logger.info("Test label shape: %s", test_label.shape)

input_shape = train_image.shape[1:]
output_units = train_label.shape[1]
logger.info("Input shape: %s", input_shape)
logger.info("Output units: %s", output_units)

# checking class distribution before SMOTE
class_distribution_train1 = np.sum(train_label, axis=0)
class_distribution_test1 = np.sum(test_label, axis=0)
logger.info("Training class distribution: %s", class_distribution_train1)
logger.info("Test class distribution: %s", class_distribution_test1)
plt.bar(range(len(class_distribution_train1)), class_distribution_train1)
plt.xlabel('Class')
plt.ylabel('Count')
plt.title('Training Class Distribution')
plt.savefig(current_path / "class_distribution.png", bbox_inches="tight")

# train_image = tf.cast(train_image, tf.float32)

# # smote
# smote = SMOTE(random_state=args.seed)
# train_image, train_label = smote.fit_resample(train_image.numpy().reshape(-1, np.prod(input_shape)), train_label)
# train_image = train_image.reshape(-1, *input_shape)

# # checking class distribution after SMOTE
# class_distribution_train = np.sum(train_label, axis=0)
# class_distribution_test = np.sum(test_label, axis=0)
# # print("Training class distribution:", class_distribution_train)
# # print("Test class distribution:", class_distribution_test)
# logger.info("Training class distribution: %s", class_distribution_train)
# logger.info("Test class distribution: %s", class_distribution_test)
# plt.bar(range(len(class_distribution_train)), class_distribution_train)
# plt.xlabel('Class')
# plt.ylabel('Count')
# plt.title('Training Class Distribution')
# plt.savefig(current_path / "pre_class_distribution.png", bbox_inches="tight")

# converting numpy array to a tensor
train_image = tf.convert_to_tensor(train_image)
train_label = tf.convert_to_tensor(train_label)

# # one-hot encoded training labels
y_train = np.argmax(train_labels, axis=1) # convert to label encoding if needed
class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
class_weight_dict = dict(enumerate(class_weights))
logger.info("Class weight dict: ", class_weight_dict)

early_stopping = EarlyStopping(monitor='val_accuracy', patience=20, restore_best_weights=True)

model = model_builder_snp.build_model(output_units, input_shape)

logger.info(model.summary())

os.chdir(current_path)

optimizer = Adam(learning_rate=args.learning_rate)

model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

## attempting data balance with class weights
history = model.fit(train_image, train_label,
                    validation_data=(test_image, test_label),
                    epochs=args.epochs,
                    batch_size=args.batch_size,
                    class_weight=class_weight_dict,
                    callbacks=[early_stopping],
                    verbose=1)

with tf.device('/GPU:0'):
    loss, accuracy = model.evaluate(test_image, test_label)
logger.info("Test loss: %s\nTest accuracy: %s", loss, accuracy)

# Predicting the probabilities
y_pred_proba = model.predict(test_image)
y_pred = np.argmax(y_pred_proba, axis=1)
y_true = np.argmax(test_label, axis=1)

logger.info("\n%s", classification_report(y_true, y_pred, target_names=class_names))

for i in range(output_units):
    precision, recall, _ = precision_recall_curve(test_label[:, i], y_pred_proba[:, i])
    pr_auc = auc(recall, precision)
    logger.info("PR AUC for class %s: %s", i, pr_auc)

save_models_snp.save_model(model=model,
           history=history,
           current_path=current_path,
           target_dir="models",
           model_name="basic.h5",
           y_pred=y_pred_proba,
           test_label=test_label)

logger.info("Making plots")
plots_snp.all_plots(history, y_pred, test_label, current_path, num_s, y_pred_proba)
logger.info("Done.")

Overwriting train_snp.py


In [383]:
!python train_snp.py --seed 20 --test_size 0.33 --learning_rate 0.00002 --batch_size 4 --epochs 100 --snp_column 0 --ylabel '/gpfs/gibbs/pi/gerstein/tu54/imaging_project/barcodes_thyroid26.csv'

2.10.1
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Experiment setup complete.
734
Final dataset (first few rows): 
          image_file                     barcode genotype
655   GTEX-ZYVF-1126  01210012111112001001121100       AA
213   GTEX-S7PM-0826  20112111122112112021111101       CC
601  GTEX-13N1W-0826  00100002110001200000211202       AA
270   GTEX-S341-0226  22112010011110021121011000       CC
171  GTEX-1EU9M-0626  21102200222122022020121112       CC
Class names: 
['AA', 'AC', 'CC']
Processing all slides: training
Finished processing.
Processing all slides: testing
Finished processing.
Train image shape: (1772, 64, 64, 128)
Train label shape: (1772, 3)
Test image shape: (478, 64, 64, 128)
Test label shape: (478, 3)
Input shape: (64, 64, 128)
Output units: 3
Training class distribution: [752. 704. 316.]
Test class distribution: [188. 205.  85.]
Class weight dict:  {0: 0.8265993265993266, 1: 0.8102310231023102, 2: 1.7985347985347986}
Model: "model"
________