In [1]:
import os
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
import sklearn as skl
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
def load_snp_csv(snp_file):
    """Read a SNP matrix from CSV and return it as an all-float DataFrame.

    The first CSV column becomes the row index (the GIDs); every remaining
    column is cast to ``float``.

    Parameters
    ----------
    snp_file : str, path-like or file-like
        Anything accepted by ``pandas.read_csv``.

    Returns
    -------
    pandas.DataFrame
        One row per index entry, with all data columns of float dtype.

    Raises
    ------
    ValueError
        If any data cell cannot be converted to ``float``.
    """
    return pd.read_csv(snp_file, index_col=0).astype(float)

class MultiModalDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields ``([images, genotypes], targets)`` batches.

    Each row of ``dataframe`` describes one sample:

    * ``'Filename'`` - name of a ``.npy`` file under ``image_path`` that is
      loaded with ``np.load``;
    * the columns listed in ``geno_feature_cols`` - genomic features, cast to
      ``float32``;
    * ``'GRYLD'`` - the regression target.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Sample table; its index is reset, the caller's frame is not modified.
    image_path : str
        Directory that contains the files named in ``'Filename'``.
    geno_feature_cols : list
        Column labels of the genomic features.
    batch_size : int, default 32
        Number of samples per batch (the last batch may be smaller).
    shuffle : bool, default True
        If true, the row order is re-shuffled on construction and every time
        ``on_epoch_end`` is called. Use ``shuffle=False`` whenever predictions
        must be matched back to rows of ``self.df``.
    """

    def __init__(self, dataframe, image_path, geno_feature_cols, batch_size=32, shuffle=True):
        # Newer Keras versions expect Sequence subclasses to initialise the
        # base class; on older versions this is a harmless no-op.
        super().__init__()
        self.df = dataframe.reset_index(drop=True)
        self.image_path = image_path
        self.geno_feature_cols = geno_feature_cols
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (a partial last batch counts)."""
        return int(np.ceil(len(self.df) / self.batch_size))

    def __getitem__(self, index):
        """Build batch number ``index``.

        Returns
        -------
        tuple
            ``([X_img, X_geno], Y)`` where ``X_img`` stacks the loaded image
            arrays, ``X_geno`` is a ``float32`` array of shape
            ``(n_samples_in_batch, len(geno_feature_cols))`` and ``Y`` holds
            the ``'GRYLD'`` values.
        """
        start = index * self.batch_size
        batch_df = self.df.iloc[start : start + self.batch_size]

        # Only the image loading is inherently per-sample.
        X_img = np.array([
            np.load(os.path.join(self.image_path, filename))
            for filename in batch_df['Filename']
        ])

        # Column-wise extraction avoids DataFrame.iterrows(), which would
        # materialise one object-dtype Series per sample.
        X_geno = batch_df[self.geno_feature_cols].to_numpy(dtype=np.float32)
        Y = batch_df['GRYLD'].to_numpy()

        return [X_img, X_geno], Y

    def on_epoch_end(self):
        """Re-shuffle the sample order when ``shuffle`` is enabled."""
        if self.shuffle:
            self.df = self.df.sample(frac=1).reset_index(drop=True)

In [3]:
# SNP genotype matrix, indexed by the first CSV column (see load_snp_csv).
snp_csv_path = "SNPs_phased_reduced.csv"
geno_df = load_snp_csv(snp_csv_path)

In [4]:
# Image manifest. header=None makes pandas treat the first line as data,
# so the column names are supplied explicitly.
csv_path = "/scratch/pawsey1157/rtrivedi/dataset/Phenotypes/Images_GIDs_GRYLD.csv"
manifest_columns = ['Filename', 'GID', 'GRYLD']
img_df = pd.read_csv(csv_path, header=None, names=manifest_columns)

In [5]:
# Attach the SNP genotypes to every image row via the shared GID key and
# drop samples that have no GRYLD value to train against.
merged_df = pd.merge(img_df, geno_df, on='GID').dropna(subset=['GRYLD'])

# Every genotype column except the GID key is a model input feature.
geno_feature_cols = [name for name in geno_df.columns if name != 'GID']
image_path = "/scratch/pawsey1157/rtrivedi/dataset/Phenotypes/stacked_npy/"

# Shuffling generator used for training.
batch_size = 32
train_gen = MultiModalDataGenerator(
    merged_df,
    image_path,
    geno_feature_cols,
    batch_size=batch_size,
    shuffle=True,
)

In [6]:
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.layers import Input

# Load the pretrained image model; it is only used as a source of layer
# configurations and weights, so it is not compiled.
image_model = load_model('CNN.h5', compile=False)

# Fresh input tensor with the same per-sample shape as the original model.
img_input = Input(shape=image_model.input_shape[1:], name='image_input')
x = img_input

# Layers to replicate: everything except the final (output) layer and any
# InputLayer. Filtering by type instead of slicing off index 0 keeps the first
# real layer when the loaded model does not list an InputLayer in `.layers`.
img_source_layers = [
    layer for layer in image_model.layers[:-1]
    if not isinstance(layer, layers.InputLayer)
]

# Rebuild each layer under a unique, prefixed name so the two branches can
# live in one model without name collisions.
# NOTE(review): this chains layers one after another, so it assumes the
# source model is a plain linear stack - confirm for CNN.h5.
for i, layer in enumerate(img_source_layers):
    config = layer.get_config()
    config['name'] = f"img_{config['name']}_{i}"  # make unique
    rebuilt = layer.__class__.from_config(config)
    x = rebuilt(x)  # calling the layer builds its variables
    # from_config() yields a freshly initialised layer; copy the trained
    # weights across so the pretrained features are actually reused.
    rebuilt.set_weights(layer.get_weights())

img_features = x
image_feature_extractor = Model(inputs=img_input, outputs=img_features)

In [7]:
# Load the pretrained genomic model; it is only used as a source of layer
# configurations and weights, so it is not compiled.
geno_model = load_model('wheat_height_predictor.h5', compile=False)

# Fresh input tensor with the same per-sample shape as the original model.
geno_input = Input(shape=geno_model.input_shape[1:], name='geno_input')
x = geno_input

# Layers to replicate: everything except the final (output) layer and any
# InputLayer. Filtering by type instead of slicing off index 0 keeps the first
# real layer when the loaded model does not list an InputLayer in `.layers`.
geno_source_layers = [
    layer for layer in geno_model.layers[:-1]
    if not isinstance(layer, layers.InputLayer)
]

# Rebuild each layer under a unique, prefixed name so the two branches can
# live in one model without name collisions.
# NOTE(review): this chains layers one after another, so it assumes the
# source model is a plain linear stack - confirm for wheat_height_predictor.h5.
for i, layer in enumerate(geno_source_layers):
    config = layer.get_config()
    config['name'] = f"geno_{config['name']}_{i}"
    rebuilt = layer.__class__.from_config(config)
    x = rebuilt(x)  # calling the layer builds its variables
    # from_config() yields a freshly initialised layer; copy the trained
    # weights across so the pretrained features are actually reused.
    rebuilt.set_weights(layer.get_weights())

geno_features = x
geno_feature_extractor = Model(inputs=geno_input, outputs=geno_features)

In [8]:
from tensorflow.keras import layers, Model

# Fusion head: concatenate the two branch outputs, pass them through one
# hidden layer with dropout, and regress a single value.
branch_outputs = [image_feature_extractor.output, geno_feature_extractor.output]
combined = layers.Concatenate(name='fusion')(branch_outputs)

fusion_dense = layers.Dense(128, activation='relu', name='fusion_dense_1')
fusion_dropout = layers.Dropout(0.3, name='fusion_dropout')
yield_head = layers.Dense(1, name='yield_output')

x = fusion_dropout(fusion_dense(combined))
output = yield_head(x)

# The model takes [image, genotype] inputs, in the same order the data
# generator emits them.
multimodal_model = Model(
    inputs=[image_feature_extractor.input, geno_feature_extractor.input],
    outputs=output,
)
multimodal_model.compile(optimizer='adam', loss='mse', metrics=['mae'])

In [9]:
# Train the fused model on the shuffling generator.
n_epochs = 10
multimodal_model.fit(x=train_gen, epochs=n_epochs)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x14754d724be0>

In [None]:
from sklearn.metrics import r2_score

# Evaluate with a NON-shuffling generator. train_gen re-shuffles its rows in
# on_epoch_end, so predicting on it and collecting the targets in a second
# pass pairs each prediction with the wrong sample and makes R² meaningless.
eval_gen = MultiModalDataGenerator(
    dataframe=merged_df,
    image_path=image_path,
    geno_feature_cols=geno_feature_cols,
    batch_size=batch_size,
    shuffle=False,
)

y_pred = multimodal_model.predict(eval_gen)

# With shuffle=False the generator serves rows in eval_gen.df order, so the
# targets can be read straight from the frame, without loading every image a
# second time.
y_true = eval_gen.df['GRYLD'].to_numpy()

# NOTE: this scores the same samples the model was trained on, so it measures
# fit rather than generalisation.
r2 = r2_score(y_true, y_pred.ravel())
print(f"R² score: {r2:.4f}")

  76/5385 [..............................] - ETA: 1:32:46