# In [None]:
import pandas as pd
from transformers import BertTokenizer, BertModel
import torch
from tqdm import tqdm

# Read the trials table; every later step in this script works off `data`.
csv_path = "/kaggle/input/trialscsv/trials.csv"
data = pd.read_csv(csv_path)
print("CSV file loaded successfully.")

# Text fields that are joined together per row before embedding.
columns_to_embed = [
    'Study Title',
    'Primary Outcome Measures',
    'Secondary Outcome Measures',
    'criteria',
]

# A single checkpoint id drives both the tokenizer and the encoder so they stay in sync.
model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
print("PubMed BERT model and tokenizer loaded.")

# Run on the GPU when CUDA is available, otherwise on the CPU, and move the encoder there.
has_gpu = torch.cuda.is_available()
device = torch.device("cuda" if has_gpu else "cpu")
print(f"Using device: {device}")
model = model.to(device)

# Function to generate embeddings
def generate_embeddings(text, tokenizer, model, *, max_length=512):
    """Embed text with *model* by mean-pooling its last hidden layer.

    Args:
        text: A single string, or a list of strings embedded as one batch.
        tokenizer: Hugging Face-style tokenizer callable that returns PyTorch
            tensors (``input_ids``, ``attention_mask``, ...).
        model: Encoder (``torch.nn.Module``) whose output exposes
            ``last_hidden_state``. Inputs are moved to the device its
            parameters live on.
        max_length: Longest token sequence kept; longer inputs are truncated.

    Returns:
        numpy.ndarray of shape ``(hidden_size,)`` for a single string, or
        ``(batch, hidden_size)`` for a list of two or more strings.
    """
    # Use the model's own device instead of a module-level global, so the
    # function works wherever the caller placed the model.
    device = next(model.parameters()).device
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state  # (batch, seq, hidden)

    mask = inputs.get("attention_mask")
    if mask is None:
        # No mask available: every position counts, as in a plain mean.
        pooled = hidden.mean(dim=1)
    else:
        # Average over real tokens only. Padding positions (mask == 0) would
        # otherwise pull batched embeddings toward the pad representation.
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)

    # Drop only the batch axis for a single input; never the hidden axis.
    return pooled.squeeze(0).cpu().numpy()

# Parallel accumulators: one embedding vector and one NCT Number per row.
embedding_data = []
nct_numbers = []

# Process the entire dataset
print("Processing all nodes...")
for idx, (_, row) in enumerate(tqdm(data.iterrows(), total=data.shape[0]), start=1):
    # Join this row's non-missing text fields into a single input string.
    combined_text = " ".join(str(row[col]) for col in columns_to_embed if pd.notna(row[col]))
    embedding_data.append(generate_embeddings(combined_text, tokenizer, model))
    nct_numbers.append(row['NCT Number'])

    # Progress message every 10 rows. tqdm.write (rather than print) clears
    # and redraws the progress bar so the message doesn't break it up.
    if idx % 10 == 0:
        tqdm.write(f"Processed {idx} rows...")

# Build the output table: NCT Number first, then one column per embedding dimension.
embedding_df = pd.DataFrame(embedding_data)
embedding_df.insert(loc=0, column='NCT Number', value=nct_numbers)

# Write the table without the pandas index, then report where it went.
output_csv_path = "embeddings_output.csv"
embedding_df.to_csv(output_csv_path, index=False)
print(f"Embeddings saved to {output_csv_path}")
