<a href="https://colab.research.google.com/github/sAndreotti/MedicalMeadow/blob/main/ATML_part2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets

In [None]:
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter

## Investigate Dataset

In [None]:
# Fetch the flashcards dataset from the Hugging Face Hub and keep only its
# 'train' split; the bare `ds` at the end displays a summary of that split.
dataset_dict = load_dataset("medalpaca/medical_meadow_medical_flashcards")
ds = dataset_dict['train']
ds

In [None]:
# Show the schema first, then one short report per text column.
print(ds.features)
print()

# (heading printed to the reader, column name in the dataset)
for heading, column_name in (
    ("Instruction", "instruction"),
    ("Input", "input"),
    ("Output", "output"),
):
    column = ds[column_name]  # fetch the column once, reuse for both lines
    print(f"{heading}:")
    print(f"length: {len(column)}")
    print(f"example: {column[0]}")
    print()

### Some plots about the dataset

In [None]:
# Extract the 'instruction' field
instructions = ds['instruction']

# Count every instruction in a single pass. (Calling `list.count` once per
# unique instruction rescans the whole column each time.)
instruction_counts = Counter(instructions)

# (instruction, count) pairs, most frequent first
sorted_instructions = instruction_counts.most_common()

# Separate the instructions and their counts for plotting
sorted_instruction_names = [name for name, _ in sorted_instructions]
sorted_instruction_counts = [count for _, count in sorted_instructions]

# Horizontal bars: instruction texts are long, so they read better on the y axis
plt.figure(figsize=(10, 5))

bars = plt.barh(
    sorted_instruction_names,
    sorted_instruction_counts,
    color='skyblue',
    edgecolor='black',
    linewidth=1.2,
)
plt.title('Instruction Frequency Distribution')
plt.xlabel('Frequency')
plt.ylabel('Instruction')

# Show the plot
plt.tight_layout()
plt.show()

In [None]:
input_phrases = ds['input']
output_phrases = ds['output']

# Length (in characters) of each phrase
input_lengths = [len(phrase) for phrase in input_phrases]
output_lengths = [len(phrase) for phrase in output_phrases]

N_LENGTH_BINS = 10


def bin_lengths(lengths, n_bins=N_LENGTH_BINS):
    """Histogram `lengths` into `n_bins` equal-width bins spanning 0..max(lengths).

    The bins start at 0 and the last bin includes the maximum, so every phrase
    is counted (the shortest and the longest phrases are not dropped).

    Args:
        lengths: sequence of non-negative integers.
        n_bins: number of equal-width bins.

    Returns:
        (labels, counts): `labels` is a list of 'low-high' strings (edges
        truncated to int) and `counts` is an integer array of the same length.
    """
    # Guard against an empty sequence or all-zero lengths so the edges stay increasing.
    upper = max(max(lengths, default=0), 1)
    edges = np.linspace(0, upper, n_bins + 1)
    counts, _ = np.histogram(lengths, bins=edges)
    labels = [f'{int(low)}-{int(high)}' for low, high in zip(edges[:-1], edges[1:])]
    return labels, counts


bin_labels_input, input_bin_counts = bin_lengths(input_lengths)
bin_labels_output, output_bin_counts = bin_lengths(output_lengths)

# One panel per column, side by side
fig, axes = plt.subplots(1, 2, figsize=(20, 10))
panels = (
    (axes[0], bin_labels_input, input_bin_counts, 'Input Phrases Length Distribution'),
    (axes[1], bin_labels_output, output_bin_counts, 'Output Phrases Length Distribution'),
)
for ax, labels, counts, title in panels:
    ax.bar(labels, counts, color='skyblue', edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel('Length Range')
    ax.set_ylabel('Number of Phrases')

# Show the plots
fig.tight_layout()
plt.show()

## Tokenize

In [None]:
import re

# Merge each input (question) with its output (answer) into a single document
merged_list = [f"{a} {b}" for a, b in zip(input_phrases, output_phrases)]

# Replace newline characters with spaces so a document is one line of text
docs = [doc.replace('\n', ' ') for doc in merged_list]

# Split into sentences at '?', '!' or '.' followed by whitespace.
# Raw string: '\s' in a normal string literal is an invalid escape sequence.
# The matched terminator and whitespace are consumed by the split.
sentences = [re.split(r'[?!.]\s', doc) for doc in docs]
sentences[:3]

In [None]:
from itertools import chain

# Flatten the per-document lists of sentences into one flat list of sentences.
# Uses the standard library instead of the private `pandas.core.common.flatten`.
# Plain strings are passed through unchanged, so re-running this cell on the
# already-flat list does not split sentences into characters.
sentences = list(
    chain.from_iterable(
        [group] if isinstance(group, str) else group for group in sentences
    )
)
sentences[:20]

In [None]:
# Raw-string pattern: '\W' in a normal string literal is an invalid escape sequence.
non_word = re.compile(r'\W')

# Replace every non-word character with a space, lowercase, split on whitespace,
# and keep only sentences that are more than 1 word long.
tokenized_sentences = [
    tokens
    for tokens in (non_word.sub(' ', sentence).lower().split() for sentence in sentences)
    if len(tokens) > 1
]

for sentence in tokenized_sentences[:5]:
    print(sentence)

## Word2Vec

In [None]:
!pip install --upgrade smart_open
!pip install --upgrade gensim

In [None]:
from gensim.models.word2vec import Word2Vec

# Hyperparameters gathered in one place so they are easy to spot and tune.
word2vec_params = dict(vector_size=30, min_count=5, window=10)

# Train word embeddings on the tokenized flashcard sentences.
model = Word2Vec(sentences=tokenized_sentences, **word2vec_params)

In [None]:
import random

SAMPLE_SIZE = 500
SAMPLE_SEED = 42

# All words known to the trained Word2Vec model
vocabulary = list(model.wv.key_to_index)

# A dedicated seeded generator makes the sampled words (and therefore the plots
# built from them) reproducible without touching the global `random` state.
sample_rng = random.Random(SAMPLE_SEED)

# Cap at the vocabulary size: random.sample raises ValueError when asked for
# more items than the population holds.
sample = sample_rng.sample(vocabulary, min(SAMPLE_SIZE, len(vocabulary)))
word_vectors = model.wv[sample]

### 3D plot with words

In [None]:
!pip install plotly
import plotly.express as px

In [None]:
import inspect

import numpy as np
from sklearn.manifold import TSNE

# scikit-learn renamed TSNE's `n_iter` argument to `max_iter`; use whichever
# name the installed version accepts so the cell runs on old and new releases.
tsne_iter_kwarg = "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"

# Project the sampled word vectors to 3 dimensions. A fixed random_state makes
# the layout reproducible across runs.
tsne = TSNE(n_components=3, random_state=42, **{tsne_iter_kwarg: 2000})
tsne_embedding = tsne.fit_transform(word_vectors)

# One coordinate array per axis
x, y, z = np.transpose(tsne_embedding)

# Plot only the first 200 points so the word labels stay readable
fig = px.scatter_3d(x=x[:200], y=y[:200], z=z[:200], text=sample[:200])
fig.update_traces(marker=dict(size=3, line=dict(width=2)), textfont_size=10)
fig.show()

In [None]:
# Words to highlight among the sampled vocabulary.
first_question = ['man', 'woman']
# Alternative query set: ['rem', 'sleep', 'hallucinations', 'paralysis']

# Query words come first, followed by the sampled words.
# NOTE(review): the lookup is expected to fail (KeyError) for any query word that
# did not make it into the Word2Vec vocabulary — confirm the words are present.
word_vectors = model.wv[first_question + sample]

# Fixed random_state so the 3D layout is reproducible across runs.
tsne = TSNE(n_components=3, random_state=42)
tsne_embedding = tsne.fit_transform(word_vectors)

# One coordinate array per axis
x, y, z = np.transpose(tsne_embedding)

In [None]:
# Same fixed range on all three axes
r = (-200, 200)

# Label only the query words, which are the first points in x/y/z; every other
# point gets no text. The padding is derived from the number of points instead
# of a hard-coded 500, so it stays correct if the sample size changes.
point_labels = first_question + [None] * (len(x) - len(first_question))

fig = px.scatter_3d(x=x, y=y, z=z, range_x=r, range_y=r, range_z=r, text=point_labels)
fig.update_traces(marker=dict(size=3, line=dict(width=2)), textfont_size=10)
fig.show()

In [None]:
# Words whose embeddings are closest to the query word in the Word2Vec model.
query_word = 'menopause'
model.wv.most_similar(query_word)

In [None]:
# Embedding arithmetic: headache + (fever - drug), then list the words whose
# vectors are closest to the result.
headache_vec = model.wv.get_vector('headache')
fever_vec = model.wv.get_vector('fever')
drug_vec = model.wv.get_vector('drug')

vec = headache_vec + (fever_vec - drug_vec)
model.wv.similar_by_vector(vec)

## Train and evaluate models

In [None]:
# Load the tokenizer and the causal language model from the Hugging Face Hub
from transformers import AutoTokenizer, AutoModelForCausalLM

# Checkpoint id shared by the tokenizer and the model.
# NOTE(review): confirm this repository id exists and is the intended checkpoint.
MODEL_ID = "decapoda-research/llama-3-70b-instruct-titan-0.1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# NOTE(review): trust_remote_code=True allows code shipped with the model repo to
# run locally — confirm the repository is trusted.
model_load_kwargs = dict(
    device_map="cuda",
    torch_dtype="auto",
    trust_remote_code=True,
)

# NOTE: this rebinds `model`, which earlier cells use for the Word2Vec model;
# the `model.wv...` cells cannot be re-run after this cell without retraining.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_load_kwargs)

In [None]:
# Build a supervised fine-tuning trainer from the loaded causal LM, its tokenizer
# and the flashcards training split.
# NOTE(review): `SFTTrainer`, `peft_params` and `training_params` are not imported
# or defined in any visible cell, so this cell raises NameError on a fresh kernel
# unless they come from a part of the notebook not shown here — confirm and add
# the missing import/configuration cell(s).
# NOTE(review): `dataset_text_field="text"` names a "text" column, but the visible
# cells only read the 'instruction', 'input' and 'output' columns of `ds` —
# confirm a "text" column is created before training.
trainer = SFTTrainer(
    model=model,
    train_dataset=ds,
    peft_config=peft_params,
    dataset_text_field="text",
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_params,
    packing=False,
)

## Add voice interactivity

## Potential extensions