In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input


def cgr_kmer_matrix(seq: str, k: int = 6) -> np.ndarray:
    """
    Build a (2^k x 2^k) CGR-like k-mer frequency matrix for a nucleotide sequence.

    Each base is encoded on two bits (A=0, C=1, G=2, T=3). For every k-mer the
    low bits of its bases are concatenated into the column index and the high
    bits into the row index, with the first base of the k-mer as the most
    significant bit. The cell of each k-mer is incremented once per occurrence
    and the matrix is finally divided by its maximum, so values lie in [0, 1].

    The sequence is upper-cased and "U" is treated as "T". Any other character
    outside "ACGT" is dropped *before* k-mers are formed, so a k-mer may span
    a position where an ambiguous symbol was removed.

    Args:
        seq: Nucleotide sequence (DNA or RNA, any case).
        k: k-mer length; must be at least 1.

    Returns:
        A float32 array of shape (2**k, 2**k). It is all zeros when the
        cleaned sequence is shorter than k.

    Raises:
        ValueError: If k is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    code = {"A": 0, "C": 1, "G": 2, "T": 3}

    size = 1 << k
    mask = size - 1  # keeps only the bits belonging to the last k bases
    mat = np.zeros((size, size), dtype=np.float32)

    # Rolling encoding: shift in the newest base and mask away the oldest one,
    # so every k-mer costs O(1) instead of being re-encoded from scratch.
    x = 0
    y = 0
    n_valid = 0
    for ch in seq.upper().replace("U", "T"):
        s = code.get(ch)
        if s is None:
            # Non-ACGT symbol: skipped, exactly as if it had been filtered out.
            continue
        x = ((x << 1) | (s & 1)) & mask
        y = ((y << 1) | ((s >> 1) & 1)) & mask
        n_valid += 1
        if n_valid >= k:
            mat[y, x] += 1.0

    peak = mat.max()
    if peak > 0:
        mat /= peak

    return mat

def cgr_to_rgb_image(mat: np.ndarray, out_size=(224, 224)) -> np.ndarray:
    """
    Turn a 2-D CGR matrix into a 3-channel float image for an ImageNet-pretrained CNN.

    The matrix is multiplied by 255 (so a matrix normalised to [0, 1] spans
    0..255), bilinearly resized to ``out_size`` and its single channel is
    copied three times to form an RGB-shaped array.

    Args:
        mat: 2-D matrix, e.g. the output of ``cgr_kmer_matrix``.
        out_size: Target (height, width); defaults to 224x224.

    Returns:
        A NumPy array of shape (out_size[0], out_size[1], 3).
    """
    scaled = tf.convert_to_tensor((mat * 255.0).astype(np.float32))
    # tf.image.resize needs an explicit channel axis.
    single_channel = tf.expand_dims(scaled, axis=-1)
    resized = tf.image.resize(single_channel, out_size, method="bilinear")
    # Duplicate the grey channel into R, G and B.
    rgb = tf.repeat(resized, repeats=3, axis=-1)
    return rgb.numpy()


def random_dna(length: int, pA: float, pC: float, pG: float, pT: float) -> str:
    """
    Draw a random DNA string of ``length`` bases from NumPy's global RNG.

    The four weights are rescaled to sum to 1, so they only need to be given
    up to a common factor.

    Args:
        length: Number of bases to generate.
        pA, pC, pG, pT: Relative weights of A, C, G and T.

    Returns:
        A string of ``length`` characters over the alphabet "ACGT".
    """
    weights = np.asarray((pA, pC, pG, pT), dtype=np.float64)
    weights /= weights.sum()
    alphabet = np.array(["A", "C", "G", "T"])
    drawn = np.random.choice(alphabet, size=length, p=weights)
    return "".join(drawn)

def synthetic_sample(seq_len=2000):
    """
    Generate one random (sequence, label) training pair.

    The label is drawn uniformly from {0, 1, 2} and selects the base
    composition used to sample the sequence:
    0=bacteria, 1=virus, 2=human (demo distributions only).

    Args:
        seq_len: Length of the generated sequence.

    Returns:
        Tuple ``(seq, label)``.
    """
    # (pA, pC, pG, pT) per label; any label not listed uses the fallback below.
    base_freqs = {
        0: (0.25, 0.25, 0.25, 0.25),
        1: (0.32, 0.18, 0.18, 0.32),
    }
    fallback = (0.24, 0.26, 0.26, 0.24)

    label = np.random.randint(0, 3)
    p_a, p_c, p_g, p_t = base_freqs.get(label, fallback)
    return random_dna(seq_len, p_a, p_c, p_g, p_t), label

def build_synthetic_dataset(n_samples=300, seq_len=2000, k=6):
    """
    Build an in-memory image dataset from randomly generated sequences.

    Every sample is produced by ``synthetic_sample``, converted to a k-mer CGR
    matrix and rendered as a 224x224x3 image.

    Args:
        n_samples: Number of (image, label) pairs to generate.
        seq_len: Length of each synthetic sequence.
        k: k-mer length used for the CGR matrix.

    Returns:
        Tuple ``(X, y)``: ``X`` is a float32 array of stacked images and ``y``
        an int32 array of the matching labels.
    """
    images, labels = [], []
    # Lazy generator: samples are still drawn one at a time, in order.
    samples = (synthetic_sample(seq_len=seq_len) for _ in range(n_samples))
    for sequence, label in samples:
        cgr = cgr_kmer_matrix(sequence, k=k)
        images.append(cgr_to_rgb_image(cgr, out_size=(224, 224)))
        labels.append(label)

    return (
        np.stack(images).astype(np.float32),
        np.array(labels, dtype=np.int32),
    )
# Two independently sampled synthetic splits: 240 training and 60 validation images.
X_train, y_train = build_synthetic_dataset(n_samples=240, seq_len=2000, k=6)
X_val, y_val = build_synthetic_dataset(n_samples=60, seq_len=2000, k=6)

# Training pipeline: shuffle (buffer of 500 covers the whole split), batch by 16, prefetch.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(500)
    .batch(16)
    .prefetch(tf.data.AUTOTUNE)
)
# Validation pipeline: same batching, but no shuffling.
val_ds = (
    tf.data.Dataset.from_tensor_slices((X_val, y_val))
    .batch(16)
    .prefetch(tf.data.AUTOTUNE)
)

# Number of output classes; matches the labels 0..2 produced by synthetic_sample.
NUM_CLASSES = 3


# ImageNet-pretrained EfficientNetB0 without its classification head,
# frozen so that only the new head is trained at first.
backbone = EfficientNetB0(include_top=False, weights="imagenet", input_shape=(224,224,3))
backbone.trainable = False

inputs = layers.Input(shape=(224,224,3))
# NOTE(review): Keras documents EfficientNet's preprocess_input as a pass-through
# (the model rescales internally and expects 0..255 inputs) — confirm for the
# installed version; if so this layer is a no-op kept for explicitness.
x = layers.Lambda(preprocess_input, name="preprocess")(inputs)
# training=False keeps the backbone in inference mode even after it is unfrozen later.
x = backbone(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.35)(x)
# New classification head: one softmax unit per class.
outputs = layers.Dense(NUM_CLASSES, activation="softmax", name="classifier")(x)

model = Model(inputs, outputs, name="CGR_EfficientNetB0_Finetune")


# Stage A: train only the new head (backbone frozen) with a relatively high learning rate.
# Labels are integer class ids, hence the sparse categorical cross-entropy.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

# Both callbacks use their default monitored quantity.
# This same list is passed to both fit() calls in this script.
callbacks = [
    # Stop after 2 epochs without improvement and roll back to the best weights.
    tf.keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True),
    # Multiply the learning rate by 0.2 after 1 epoch without improvement.
    tf.keras.callbacks.ReduceLROnPlateau(patience=1, factor=0.2),
]

print("\n=== Stage A: Feature Extraction (Head Only) ===")
model.summary()
model.fit(train_ds, validation_data=val_ds, epochs=5, callbacks=callbacks)


# Stage B: unfreeze the backbone, then re-freeze everything except its last
# N_UNFREEZE layers so only the top of the network is fine-tuned.
backbone.trainable = True
N_UNFREEZE = 40
for layer in backbone.layers[:-N_UNFREEZE]:
    layer.trainable = False

# Recompile after changing `trainable` so the change takes effect; the learning
# rate is 100x lower than in Stage A (1e-5 vs 1e-3).
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

print("\n=== Stage B: Fine-Tuning (Last Layers Unfrozen) ===")
model.fit(train_ds, validation_data=val_ds, epochs=5, callbacks=callbacks)

# Persist the fine-tuned model in the native Keras format.
# NOTE(review): the model contains a Lambda layer; reloading may need the wrapped
# function to be importable / safe_mode adjusted — confirm when loading.
model.save("cgr_species_finetuned.keras")
print("\nSaved: cgr_species_finetuned.keras")

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5
[1m16705208/16705208[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step

=== Stage A: Feature Extraction (Head Only) ===


Epoch 1/5
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 2s/step - accuracy: 0.3654 - loss: 1.1152 - val_accuracy: 0.6667 - val_loss: 0.9283 - learning_rate: 0.0010
Epoch 2/5
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 1s/step - accuracy: 0.5334 - loss: 0.9870 - val_accuracy: 0.6833 - val_loss: 0.8304 - learning_rate: 0.0010
Epoch 3/5
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 2s/step - accuracy: 0.5378 - loss: 0.9210 - val_accuracy: 0.6667 - val_loss: 0.7651 - learning_rate: 0.0010
Epoch 4/5
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 1s/step - accuracy: 0.6243 - loss: 0.8416 - val_accuracy: 0.7000 - val_loss: 0.7204 - learning_rate: 0.0010
Epoch 5/5
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 1s/step - accuracy: 0.5977 - loss: 0.7938 - val_accuracy: 0.7000 - val_loss: 0.6832 - learning_rate: 0.0010

=== Stage B: Fine-Tuning (Last Layers Unfrozen) ===
Epoch 1/5
[1m15/15[0m [32m━━━