# Finetuning

Finetune geneformer on immune-c2s cell-type classification task

In [74]:
# /ihome/kyin/niandrew/.conda/envs/tf_gpu
# /ix1/kyin/niandrew/custom_miniconda
# source /ix1/kyin/niandrew/custom_miniconda/bin/activate tf_gpu
!source /ix1/kyin/niandrew/custom_miniconda/bin/activate tf_gpu
!export XLA_FLAGS=--xla_gpu_cuda_data_dir=/ihome/kyin/niandrew/.conda/envs/tf_gpu/lib/
!export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/ihome/kyin/niandrew/.conda/envs/tf_gpu/lib/

# !export CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))
# !export LD_LIBRARY_PATH=${CUDNN_PATH}/lib

  pid, fd = os.forkpty()


## Load Packages

In [75]:
from datasets import load_dataset
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt
import datetime
import pandas as pd

import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import keras_nlp
import tensorflow as tf
import keras
from keras import backend as K

## Parameters

In [76]:
# Preprocessing params.
PRETRAINING_BATCH_SIZE = 12
FINETUNING_BATCH_SIZE = 12
SEQ_LENGTH = 512
MASK_RATE = 0.125
PREDICTIONS_PER_SEQ = 128

# Model params.
NUM_LAYERS = 6
MODEL_DIM = 256
INTERMEDIATE_DIM = 512
NUM_HEADS = 4
DROPOUT = 0.02
NORM_EPSILON = 1e-12
VOCAB_SIZE = 25427

# Training params.
FINETUNING_LEARNING_RATE = 1e-6
FINETUNING_WEIGHT_DECAY = 0.001
FINETUNING_EPOCHS = 3

# Model name
MODEL_NAME = 'geneformer'
FINETUNE_MODEL_NAME = 'geneformer_cell_classifier_binary'

## Load Data

In [77]:
# Load Vocab File
with open('token_dictionary.pkl', 'rb') as file:
    vocab_dict = pickle.load(file)
vocab_list = list(vocab_dict.keys())
vocab_list.append('<unk>')
VOCAB_SIZE = len(vocab_list)

# Load Cell Type Label File
with open('immune-c2s/label_dictionary.pkl', 'rb') as file:
    label_dict = pickle.load(file)
label_list = list(label_dict.keys())
NUM_CLASSES = len(label_dict)
NUM_CLASSES

35

In [78]:
# Finetuning

# Tokenizer
tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=vocab_list,
    sequence_length=SEQ_LENGTH,
    lowercase=False,
    oov_token='<unk>',
    split=True
)

def preprocess(inputs, labels):
    return tokenizer(inputs), tf.one_hot(labels, depth=NUM_CLASSES) # tf.one_hot(labels, depth=NUM_CLASSES)

In [79]:
train = tf.data.experimental.CsvDataset('immune-c2s/train.csv', [tf.string, tf.int32], header=False).batch(FINETUNING_BATCH_SIZE)
test = tf.data.experimental.CsvDataset('immune-c2s/test.csv', [tf.string, tf.int32], header=False).batch(FINETUNING_BATCH_SIZE)
val = tf.data.experimental.CsvDataset('immune-c2s/val.csv', [tf.string, tf.int32], header=False).batch(FINETUNING_BATCH_SIZE)

# Pre-compute preprocessed batches on the fly on the CPU.
finetune_ds = train.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

finetune_test_ds = test.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

finetune_val_ds = val.map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE
).prefetch(tf.data.AUTOTUNE)

print(finetune_val_ds.take(1).get_single_element())

(<tf.Tensor: shape=(12, 512), dtype=int32, numpy=
array([[12172,  1720, 17247, ...,  3874, 10401,  4934],
       [12172,  1720, 16979, ...,  1509,  4347,  8792],
       [17247, 17326, 17905, ...,  3259, 11496, 11868],
       ...,
       [ 1720, 12172, 16596, ...,  1933, 16593,  5120],
       [25426, 17303, 17247, ...,  6319, 25251, 12347],
       [12172, 17200,  3567, ...,  7903,  2298, 11467]], dtype=int32)>, <tf.Tensor: shape=(12, 35), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0

## Finetuning

In [80]:
# Reload the encoder model from disk so we can restart fine-tuning from scratch.
encoder_model = keras.models.load_model("model/" + MODEL_NAME + ".keras", compile=False)

In [81]:
# Freeze
for layer in encoder_model.layers:
    if layer.name.startswith('transformer_encoder_'):
        if (int(layer.name.split('_')[-1]) <= 3):
            layer.trainable = False
    elif layer.name.startswith('transformer_encoder'):
        layer.trainable = False
    elif layer.name == 'token_and_position_embedding':
        layer.trainable = False
    elif layer.name == 'layer_normalization':
        layer.trainable = False

In [82]:
# Take as input the tokenized input.
inputs = keras.Input(shape=(SEQ_LENGTH,), dtype="int32")

# Encode and pool the tokens.
outputs = encoder_model(inputs)

outputs = keras.layers.GlobalAveragePooling1D()(outputs)

# Predict an output label.
outputs = keras.layers.Dense(256, activation="relu")(outputs)
outputs = keras.layers.Dense(NUM_CLASSES, activation="sigmoid")(outputs)

# Define and compile our fine-tuning model.
finetuning_model = keras.Model(inputs, outputs)
finetuning_model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.AdamW(learning_rate=FINETUNING_LEARNING_RATE,
                                        weight_decay=FINETUNING_WEIGHT_DECAY),
    metrics=["accuracy"],
    jit_compile=True
)

finetuning_model.summary(expand_nested=True, show_trainable=True)

In [83]:
log_dir = "logs/fine_tune_fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

In [84]:
history = finetuning_model.fit(
    finetune_ds,
    validation_data=finetune_val_ds,
    callbacks=[tensorboard_callback],
    epochs=1,
    shuffle=True
)

# Save this base model for further finetuning
finetuning_model.save(("model/" + FINETUNE_MODEL_NAME + ".keras"))

# Save training history
with open("model/" + FINETUNE_MODEL_NAME + "_history.pkl", 'wb') as file:
    pickle.dump(history, file)

  14582/Unknown [1m121s[0m 8ms/step - accuracy: 0.1030 - loss: 3.5248

2024-05-05 23:22:19.701580: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


  14583/Unknown [1m125s[0m 8ms/step - accuracy: 0.1030 - loss: 3.5248

2024-05-05 23:22:32.013290: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


[1m14583/14583[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m136s[0m 9ms/step - accuracy: 0.1030 - loss: 3.5248 - val_accuracy: 0.1367 - val_loss: 3.4279
