# Weight Technique 3: Multimodal Model Training

Multimodal weights refer to a deep learning technique that combines information from multiple sources (modalities) to improve model performance. In the context of MedLM, this means leveraging different types of medical information, such as:

- Text: medical records and patient notes from the MIMIC-IV dataset.

The technique typically involves three steps:

- Feature Extraction: Each modality is processed separately to extract meaningful features. For text, these can be word embeddings, sentence embeddings, or other text-based features.
- Feature Fusion: The extracted features from different modalities are combined into a single representation. This can be done using techniques like:
   1. Concatenation: Simply joining the feature vectors from each modality end to end.
   2. Attention Mechanisms: Learning to focus on the most relevant features from each modality.
- Model Training: The combined multimodal features are used to train the MedLM model. The model learns to associate these features with the desired output (e.g., diagnosis, treatment recommendations).
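The two fusion strategies above can be sketched with NumPy. This is a minimal, framework-free illustration, assuming each modality has already been reduced to a fixed-length vector. The second modality (lab values), the helper names `concat_fusion` and `attention_fusion`, and the fixed `query` vector are all hypothetical; in a real model the query (and usually projection layers) would be learned.

```python
import numpy as np

def concat_fusion(modality_features):
    """Concatenation: join each modality's feature vector end to end."""
    return np.concatenate(modality_features, axis=-1)

def attention_fusion(modality_features, query):
    """Attention-style fusion: score each modality against a query vector,
    softmax the scores, and return the weighted sum of the modality vectors.
    All modality vectors must have the same dimension as `query`."""
    stacked = np.stack(modality_features)        # (n_modalities, dim)
    scores = stacked @ query                     # (n_modalities,)
    weights = np.exp(scores - scores.max())      # numerically stable softmax
    weights /= weights.sum()
    return weights @ stacked, weights            # fused (dim,), weights (n_modalities,)

# Hypothetical 4-dimensional features for two modalities (note text, lab values)
text_vec = np.array([0.2, 0.1, 0.0, 0.7])
lab_vec = np.array([0.5, 0.5, 0.0, 0.0])
query = np.array([0.0, 0.0, 0.0, 1.0])  # hypothetical; learned in a real model

fused_concat = concat_fusion([text_vec, lab_vec])
fused_attn, attn_weights = attention_fusion([text_vec, lab_vec], query)
print(fused_concat.shape)  # (8,)
print(attn_weights)        # the text modality scores higher against this query
```

Concatenation keeps every feature but grows the input size with each modality; attention keeps the output size fixed and lets the model weight modalities per example.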


In [1]:
import json
from google.cloud import aiplatform
from google.cloud.aiplatform.gapic.schema import predict
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from google.cloud import bigquery
from langchain_openai import ChatOpenAI
from langchain_experimental.sql import SQLDatabaseChain
from langchain.sql_database import SQLDatabase
import warnings
warnings.filterwarnings('ignore')

In [2]:
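import os

# Read the OpenAI API key from the environment (assumes OPENAI_API_KEY is set)
openai_api_key = os.environ["OPENAI_API_KEY"]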

In [34]:
# Initialize BigQuery client
bigquery_client = bigquery.Client()

# Define the LLM that translates natural-language questions into SQL (OpenAI GPT-4-32k)
llm = ChatOpenAI(openai_api_key=openai_api_key, model="gpt-4-32k")

# Manually define the SQLDatabase
# Assume you have a BigQuery connection string or credentials file
connection_string = "bigquery://us-gcp-ame-con-5b680-sbx-1/mimic_iv_hosp_icu_dataset"

# Create the SQLDatabase instance
db = SQLDatabase.from_uri(connection_string)

# Create the SQLDatabaseChain
chain = SQLDatabaseChain.from_llm(llm, db)

# Define your natural language query
natural_language_query = "List of medication prescribed to patients treated for post op limit 2"

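# Generate SQL, execute it, and let the LLM summarize the result.
# Note: chain.run returns the chain's final answer, not the SQL statement itself.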
sql_query = chain.run(natural_language_query)
print("Generated SQL:", sql_query)

Generated SQL: subject_id	hadm_id	drug
10000032	22595853	Aspirin
10000032	22595853	Ibuprofen

Answer: The medication prescribed to patients treated for post op includes Aspirin and Ibuprofen.
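`SQLDatabaseChain` works in stages: the LLM turns the question into SQL, the SQL runs against the database, and the LLM phrases the rows as an answer (which is why the printed "Generated SQL" above is actually a summary). The sketch below mimics the first two stages offline with `sqlite3`. The function `fake_llm_to_sql`, the fixed SQL it returns, and the table rows are hypothetical; only the column names mirror the output above.

```python
import sqlite3

def fake_llm_to_sql(question):
    """Hypothetical stand-in for the LLM step: returns a fixed SQL string.
    In the notebook, GPT-4 inside SQLDatabaseChain performs this translation."""
    return "SELECT subject_id, hadm_id, drug FROM prescriptions LIMIT 2"

# Toy in-memory table with hypothetical rows
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE prescriptions (subject_id INTEGER, hadm_id INTEGER, drug TEXT)")
conn.executemany(
    "INSERT INTO prescriptions VALUES (?, ?, ?)",
    [(1, 100, "Drug A"), (1, 100, "Drug B"), (2, 200, "Drug C")],
)

question = "List of medication prescribed to patients treated for post op limit 2"
sql = fake_llm_to_sql(question)       # stage 1: natural language -> SQL
rows = conn.execute(sql).fetchall()   # stage 2: run the SQL ("limit 2" becomes LIMIT 2)
print(sql)
print(rows)                           # stage 3 (not shown): the LLM summarizes these rows
conn.close()
```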


In [9]:
import numpy as np

In [35]:
import tensorflow as tf
from tensorflow.keras.layers import Embedding, LSTM, Dense
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Value

client_options = {"api_endpoint": "us-central1-aiplatform.googleapis.com"}

# Initialize the client used to create and send requests.
# This client only needs to be created once and can be reused for multiple requests.
client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)

# Define the MedLM endpoint
endpoint = "projects/us-gcp-ame-con-5b680-sbx-1/locations/us-central1/publishers/google/models/medlm-large"

# Example reference text
reference_text = 'The medication prescribed to patients treated for post op includes Aspirin and Ibuprofen'

# Prepare the instance for prediction
instance_dict = {"content": reference_text}
instance = json_format.ParseDict(instance_dict, Value())
instances = [instance]

# Define parameters for the MedLM model (optional)
parameters_dict = {
    "candidateCount": 1,
    "maxOutputTokens": 500,
    "temperature": 0.2,
    "topP": 0.8,
    "topK": 40
}

parameters = json_format.ParseDict(parameters_dict, Value())
# Send the prediction request to MedLM
response = client.predict(endpoint=endpoint, instances=instances, parameters=parameters)

# Extract the generated text from the MedLM response.
# medlm-large returns generated text in the "content" field, not embeddings,
# so `features` ends up as a 0-d NumPy array holding a string.
medlm_predictions = response.predictions

for prediction in medlm_predictions:
    generated_text = prediction.get('content', None)
    # Check whether any text was returned
    if generated_text is not None:
        features = np.array(generated_text)
        print(features)
    else:
        print("No content found in the response.")

print(type(features))


 Aspirin and ibuprofen are both nonsteroidal anti-inflammatory drugs (NSAIDs) that are commonly used to relieve pain and inflammation. They work by blocking the production of prostaglandins, which are chemicals that are involved in the inflammatory response. Aspirin is a salicylate, while ibuprofen is a propionic acid derivative. Both drugs are available over-the-counter and by prescription. Aspirin is typically used for short-term pain relief, while ibuprofen can be used for both short-term and long-term pain management. In addition to their analgesic and anti-inflammatory effects, aspirin and ibuprofen can also have other effects on the body. For example, aspirin can inhibit platelet aggregation, which is why it is sometimes used to prevent blood clots. Ibuprofen can also have effects on the kidneys and gastrointestinal tract. It is important to follow the dosing instructions on the medication label and to talk to your doctor or pharmacist if you have any questions or concerns about 
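The next cell turns this MedLM answer into a numeric vector with scikit-learn's `TfidfVectorizer`, fitted on that single text. With only one document, every term has the same document frequency, so the IDF factor is constant and the result is just the L2-normalised term counts. The pure-Python sketch below reproduces that behaviour, assuming the vectorizer's defaults (lowercasing, the default token pattern of two or more word characters, `smooth_idf=True`, `norm='l2'`); `single_doc_tfidf` is a hypothetical helper, and tie-breaking at the `max_features` cutoff may differ from scikit-learn's.

```python
import math
import re
from collections import Counter

def single_doc_tfidf(text, max_features=100):
    """Hypothetical helper mimicking TfidfVectorizer's defaults on a one-document corpus:
    lowercasing, tokens of 2+ word characters, smoothed IDF, and L2 normalisation."""
    tokens = re.findall(r"(?u)\b\w\w+\b", text.lower())
    counts = Counter(tokens)
    # Keep the most frequent terms; columns are ordered alphabetically by term
    vocab = sorted(term for term, _ in counts.most_common(max_features))
    n_docs, doc_freq = 1, 1  # one document, so every kept term appears in it
    idf = math.log((1 + n_docs) / (1 + doc_freq)) + 1  # smoothed IDF, equals 1.0 here
    weights = [counts[term] * idf for term in vocab]
    norm = math.sqrt(sum(w * w for w in weights))
    return vocab, [w / norm for w in weights]

vocab, vector = single_doc_tfidf("Aspirin and ibuprofen. Aspirin relieves pain.")
print(vocab)   # ['and', 'aspirin', 'ibuprofen', 'pain', 'relieves']
print(vector)  # L2-normalised counts; 'aspirin' gets twice the weight of the others
```

In other words, the `features` vector fed to the model encodes word frequencies of one reference answer; IDF only starts to matter once the vectorizer is fitted on several documents.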

In [55]:
import tensorflow as tf
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np

# Example text data
text_data = 'Aspirin and ibuprofen are both nonsteroidal anti-inflammatory drugs (NSAIDs) that are commonly used to relieve pain and inflammation. They work by blocking the production of prostaglandins, which are chemicals that are involved in the inflammatory response. Aspirin is a salicylate, while ibuprofen is a propionic acid derivative. Both drugs are available over-the-counter and by prescription. Aspirin is typically used for short-term pain relief, while ibuprofen can be used for both short-term and long-term pain management. In addition to their analgesic and anti-inflammatory effects, aspirin and ibuprofen can also have other effects on the body. For example, aspirin can inhibit platelet aggregation, which is why it is sometimes used to prevent blood clots. Ibuprofen can also have effects on the kidneys and gastrointestinal tract. It is important to follow the dosing instructions on the medication label and to talk to your doctor or pharmacist if you have any questions or concerns about aspirin or ibuprofen.'

# Convert text data to numerical features using TF-IDF
vectorizer = TfidfVectorizer(max_features=100)  # Limit to top 100 features for simplicity
features = vectorizer.fit_transform([text_data]).toarray().flatten()  # Convert to 1D array

# Define model parameters
max_length = 100  # Maximum number of tokens per input sequence
vocab_size = 5000  # Tokenizer vocabulary size; adjust as needed
embedding_dim = 128  # Size of each word embedding vector

# Create the model
def create_model(max_length, feature_dim):
    input_text = tf.keras.Input(shape=(max_length,), dtype="int32")
    reference_features = tf.keras.Input(shape=(feature_dim,))

    # Embed the input text
    embedded_text = Embedding(vocab_size, embedding_dim)(input_text)

    # Process the input text using an LSTM
    lstm_output = LSTM(128)(embedded_text)

    # Concatenate the LSTM output with the reference features
    combined_features = tf.keras.layers.concatenate([lstm_output, reference_features])

    # Output layer (intended for summarization): a single softmax over max_length units
    summary = Dense(max_length, activation="softmax")(combined_features)

    return tf.keras.Model(inputs=[input_text, reference_features], outputs=summary)

# Calculate feature dimension
feature_dim = features.shape[0]

# Create the model
model = create_model(max_length, feature_dim)

# Compile the model
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# Example input text and reference text for training
input_texts = [
    'The medication prescribed to patients treated for post op includes Aspirin and Ibuprofen.',
    'Aspirin and ibuprofen are used to relieve pain and inflammation.'
]
reference_text = text_data  # Same MedLM answer used for the TF-IDF features above

# Tokenize the texts
tokenizer = Tokenizer(num_words=vocab_size)
tokenizer.fit_on_texts(input_texts + [reference_text])

# Convert texts to sequences
input_sequences = tokenizer.texts_to_sequences(input_texts)
reference_sequence = tokenizer.texts_to_sequences([reference_text])[0]

# Pad the sequences
input_text_data = pad_sequences(input_sequences, maxlen=max_length, padding='post')
reference_features_data = np.tile(features, (len(input_texts), 1))

# Prepare target text data (example)
target_texts = [
    "Aspirin and ibuprofen are NSAIDs used for pain relief.",
    "Both drugs are available over-the-counter and by prescription."
]
target_sequences = tokenizer.texts_to_sequences(target_texts)
# Note: these targets are raw, zero-padded word indices, not probability distributions,
# so they do not match the softmax output and categorical_crossentropy loss above.
target_text_data = pad_sequences(target_sequences, maxlen=max_length, padding='post')

# Fit the model
model.fit(
    [input_text_data, reference_features_data],
    target_text_data,
    epochs=10,
    batch_size=32,
)

Epoch 1/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - accuracy: 0.0000e+00 - loss: 1007.7798
Epoch 2/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 91ms/step - accuracy: 0.0000e+00 - loss: 997.7004
Epoch 3/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 79ms/step - accuracy: 0.0000e+00 - loss: 986.8818
Epoch 4/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 83ms/step - accuracy: 0.0000e+00 - loss: 973.7292
Epoch 5/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 136ms/step - accuracy: 0.0000e+00 - loss: 955.9148
Epoch 6/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 82ms/step - accuracy: 0.0000e+00 - loss: 929.0397
Epoch 7/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - accuracy: 0.0000e+00 - loss: 885.1016
Epoch 8/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 77ms/step - accuracy: 0.0000e+00 - loss: 821.2887
Epoch 9/10
1

<keras.src.callbacks.history.History at 0x7f7274118ee0>
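The training log shows losses near 1000 and zero accuracy. One likely reason: `categorical_crossentropy` expects each target row to be a probability distribution over the model's `max_length` softmax outputs, but `target_text_data` holds raw, zero-padded word indices. Cross-entropy is -Σ y·log(p), so with integer targets the loss grows with the sum of the token ids. The NumPy sketch below compares both cases for an untrained, uniform softmax; the token ids are hypothetical, and the loss function is a simplified re-implementation rather than the Keras one.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-example categorical cross-entropy, -sum(y_true * log(y_pred)),
    with y_pred clipped away from 0 before taking the log."""
    y_pred = np.clip(y_pred, eps, 1.0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

max_length = 100
uniform = np.full(max_length, 1.0 / max_length)  # untrained softmax over 100 outputs

# Hypothetical padded token-id target, shaped like one row of target_text_data
token_id_target = np.zeros(max_length)
token_id_target[:9] = [3, 4, 5, 6, 12, 13, 14, 15, 16]

# A valid probability target over the same 100 outputs: one-hot
one_hot_target = np.zeros(max_length)
one_hot_target[3] = 1.0

token_loss = categorical_crossentropy(token_id_target, uniform)
one_hot_loss = categorical_crossentropy(one_hot_target, uniform)
print(token_loss)    # scales with the sum of the token ids (88 here)
print(one_hot_loss)  # ln(100), about 4.61
```

A target that matches the loss keeps it on a meaningful scale, for example a multi-hot bag-of-words vector from the tokenizer's `texts_to_matrix(..., mode="binary")` paired with a `vocab_size`-wide sigmoid output and `binary_crossentropy`. That is one possible reformulation, not the notebook's original design.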

In [None]:
# Query the MIMIC-IV-Note dataset for test data

In [65]:
# Initialize BigQuery client
bigquery_client = bigquery.Client()

# Define the LLM that translates natural-language questions into SQL (OpenAI GPT-4-32k)
llm = ChatOpenAI(openai_api_key=openai_api_key, model="gpt-4-32k")

# Manually define the SQLDatabase
# Assume you have a BigQuery connection string or credentials file
connection_string = "bigquery://us-gcp-ame-con-5b680-sbx-1/mimic_iv_dataset"

# Create the SQLDatabase instance
db = SQLDatabase.from_uri(connection_string)

# Create the SQLDatabaseChain
chain = SQLDatabaseChain.from_llm(llm, db)

# Define your natural language query
natural_language_query = "List of medication in discharge text where patients are treated for post op limit 2"

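# Generate SQL, execute it, and let the LLM summarize the result.
# Note: chain.run returns the chain's final answer, not the SQL statement itself.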
sql_query = chain.run(natural_language_query)
print("Generated SQL:", sql_query)

Generated SQL: The two medications in the discharge texts where patients are treated for post op are CHF and critical aortic stenosis for the first patient and Vioxx for the second patient.


In [66]:
# Example reference text
reference_text = 'The two medications in the discharge texts where patients are treated for post op are CHF and critical aortic stenosis for the first patient and Vioxx for the second patient.'

# Prepare the instance for prediction
instance_dict = {"content": reference_text}
instance = json_format.ParseDict(instance_dict, Value())
instances = [instance]

# Define parameters for the MedLM model (optional)
parameters_dict = {
    "candidateCount": 1,
    "maxOutputTokens": 500,
    "temperature": 0.2,
    "topP": 0.8,
    "topK": 40
}

parameters = json_format.ParseDict(parameters_dict, Value())
# Send the prediction request to MedLM
response = client.predict(endpoint=endpoint, instances=instances, parameters=parameters)

# Extract the generated text from the MedLM response.
# medlm-large returns generated text in the "content" field, not embeddings,
# so `features` ends up as a 0-d NumPy array holding a string.
medlm_predictions = response.predictions

for prediction in medlm_predictions:
    generated_text = prediction.get('content', None)
    # Check whether any text was returned
    if generated_text is not None:
        features = np.array(generated_text)
        print(features)
    else:
        print("No content found in the response.")

print(type(features))


 CHF (congestive heart failure) and critical aortic stenosis are not medications, they are medical conditions. Vioxx is a medication, but it is not used to treat post-operative pain. It is a non-steroidal anti-inflammatory drug (NSAID) that was used to treat arthritis and other chronic pain conditions. However, it was withdrawn from the market in 2004 due to concerns about its safety.
<class 'numpy.ndarray'>
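Both MedLM cells wrap the model's answer in `np.array(...)`, but the `content` field of this prediction is generated text rather than an embedding. Wrapping a Python string in `np.array` produces a zero-dimensional array with a Unicode string dtype, which is what the later `features dtype: <U387, shape: ()` check reports and why Keras rejects it as model input. The snippet below uses a hypothetical prediction dict in place of the real response to show the effect.

```python
import numpy as np

# Hypothetical stand-in for one element of response.predictions
prediction = {"content": "Vioxx is a medication, but it is not used to treat post-operative pain."}

generated_text = prediction.get("content")
as_array = np.array(generated_text)

print(as_array.shape)       # () : a 0-d array holding a single string
print(as_array.dtype.kind)  # U : Unicode string dtype, not a float feature vector
```

To obtain numeric features, the text has to be encoded first, for example with the TF-IDF vectorizer fitted in the training cell or with a dedicated text-embedding model.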


In [74]:
# Evaluate the model on test data
# Test data
test_input_texts = [
    'The two medications in the discharge texts where patients are treated for post op are CHF and critical aortic stenosis for the first patient and Vioxx for the second patient',
]
test_reference_text = ''' CHF (congestive heart failure) and critical aortic stenosis are not medications, 
they are medical conditions. Vioxx is a medication, but it is not used to treat post-operative pain. 
It is a non-steroidal anti-inflammatory drug (NSAID) that was used to treat arthritis and other chronic pain conditions.
However, it was withdrawn from the market in 2004 due to concerns about its safety.'''

# Tokenize and pad the test inputs with the tokenizer fitted during training.
# Fitting a new Tokenizer here would assign word indices that differ from
# the ones the model was trained on.
test_input_sequences = tokenizer.texts_to_sequences(test_input_texts)
test_input_text_data = pad_sequences(test_input_sequences, maxlen=max_length, padding='post')

# Encode the MedLM reference text with the TF-IDF vectorizer fitted during training,
# so the feature vector has the dimension the model expects (feature_dim).
# `features` cannot be reused here: after the MedLM cell it holds generated text
# (a 0-d string array), which Keras rejects with "ValueError: Invalid dtype".
test_features = vectorizer.transform([test_reference_text]).toarray().flatten()
test_reference_features_data = np.tile(test_features, (len(test_input_texts), 1))

# Generate predictions
predictions = model.predict([test_input_text_data, test_reference_features_data])

# Decode predictions: the model has a single softmax over max_length outputs,
# so argmax yields one index per example, which is looked up as a word index.
predicted_indices = np.argmax(predictions, axis=1)
predicted_texts = tokenizer.sequences_to_texts([[int(idx)] for idx in predicted_indices])
print("Predicted:", predicted_texts)

# Calculate metrics. model.evaluate needs target data to compute the loss:
# build test_target_text_data the same way as target_text_data in the
# training cell before uncommenting these lines.
# loss, accuracy = model.evaluate(
#     [test_input_text_data, test_reference_features_data],
#     test_target_text_data,
# )
# print("Loss:", loss)
# print("Accuracy:", accuracy)

# You can also use other metrics like ROUGE score for summarization evaluation

In [78]:
print(f"features dtype: {features.dtype}, shape: {features.shape}")


features dtype: <U387, shape: ()
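The evaluation cell mentions ROUGE as a better fit than accuracy for summarization. As a minimal sketch of the idea (not a replacement for a maintained package such as `rouge-score`), ROUGE-1 counts clipped unigram overlap between a candidate and a reference summary. This version assumes lowercase whitespace tokenisation with no stemming, and the candidate sentence is hypothetical.

```python
from collections import Counter

def rouge_1(candidate, reference):
    """ROUGE-1 precision, recall and F1 from clipped unigram overlap,
    using lowercase whitespace tokenisation and no stemming."""
    cand = Counter(candidate.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((cand & ref).values())  # each word counted at most min(cand, ref) times
    if overlap == 0:
        return 0.0, 0.0, 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

reference = "aspirin and ibuprofen are nsaids used for pain relief"
candidate = "aspirin and ibuprofen relieve pain"  # hypothetical model output
p, r, f = rouge_1(candidate, reference)
print(f"precision={p:.3f} recall={r:.3f} f1={f:.3f}")
```

Here four of the five candidate words appear in the nine-word reference, so precision is 0.8 and recall is 4/9.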


# Conclusion: Improving MedLM with Multimodal Weights

Here's how multimodal weights can improve the MedLM model:
1. Enhanced Context: By incorporating information from multiple modalities, the model gains a richer understanding of the medical context. This can lead to more accurate predictions.
2. Complementary Information: Different modalities often provide complementary information. For example, a medical scan might reveal a tumor, while a patient's symptoms might provide additional clues about the tumor's characteristics.
3. Robustness: Multimodal models are often more robust to noise or missing data in a single modality.
Example (conceptual):

Suppose you are training a MedLM model to predict a patient's risk of developing diabetes. You could use:

Text: Patient medical records, including family history, lifestyle factors, and previous diagnoses.

# Recommendations: Challenges and Considerations

Data Availability: Collecting and annotating multimodal data can be challenging and expensive.

Computational Resources: Training multimodal models can require significant computational resources.

Feature Engineering: Carefully selecting and engineering features from each modality is crucial for successful multimodal learning.