In [1]:
from datasets import load_dataset

# Pull the restructured cancer clinical-trial dataset from the Hugging Face Hub,
# then show its split summary followed by one raw training row for inspection.
ravis_dataset = load_dataset(
    "ravistech/clinical-trial-llm-cancer-restructure"
)
print(ravis_dataset, ravis_dataset["train"][0], sep="\n")

DatasetDict({
    train: Dataset({
        features: ['metadata', 'data', 'criteria', '__index_level_0__'],
        num_rows: 31840
    })
    validation: Dataset({
        features: ['metadata', 'data', 'criteria', '__index_level_0__'],
        num_rows: 3980
    })
    test: Dataset({
        features: ['metadata', 'data', 'criteria', '__index_level_0__'],
        num_rows: 3981
    })
})
{'metadata': '{\n"NCT_ID" : "NCT00933777",\n"Brief_Title" : "SORAVE - Sorafenib and Everolimus in Solid Tumors",\n\n"Official_title" : "SORAVE-Sorafenib and Everolimus in Solid Tumors. A Phase I Clinical Trial to Evaluate the Safety of Combined Sorafenib and Everolimus Treatment in Patients With Relapsed Solid Tumors",\n\n"Conditions" : [Unspecified Adult Solid Tumor, Protocol Specific, Non-Small Cell Lung Cancer],\n\n"Interventions" : [Drug: Combination of sorafenib and everolimus],\n\n"Location_Countries" : [Germany],\n\n"Study_Design" : {\n"Study_Type" : "INTERVENTIONAL",\n"Phase" : [PHASE1],\n"P

In [2]:
import re
def fix_invalid_json(input_str):
    """Best-effort repair of the pseudo-JSON stored in the dataset's ``metadata`` field.

    The raw metadata contains unquoted list elements, e.g.
    ``"Conditions" : [Unspecified Adult Solid Tumor, Protocol Specific]``, which is
    not valid JSON. This function wraps such elements (and any bare object keys)
    in double quotes. Callers still pass the result through ``json_repair`` for
    anything these regexes do not cover.

    Args:
        input_str: the raw metadata string.

    Returns:
        The repaired string. Already-quoted list elements, whitespace-only lists
        and words followed by ``:`` inside string values are left untouched.
    """
    def _quote_elements(match):
        # Wrap each ", "-separated element of a bracketed list in double quotes.
        content = match.group(0)
        stripped = content.strip()
        if not stripped:
            # `[ ]` is already a valid (empty) JSON list; don't turn it into `[""]`
            return content
        return ", ".join('"' + part.strip() + '"' for part in stripped.split(", "))

    ## single-element lists whose element is not already quoted, e.g. [Germany];
    ## `"` is excluded so ["PHASE1"] is not double-quoted into [""PHASE1""]
    fixed_str = re.sub(r'(?<=\[)([^\[\]",]+)(?=\])', _quote_elements, input_str)

    ## multi-element lists such as Conditions / Interventions, e.g. [A, B, C]
    fixed_str = re.sub(r'(?<=\[)([^"\]]+?)(?=\])', _quote_elements, fixed_str)

    ## bare object keys: only a word at a key position (line start, after `{` or `,`),
    ## so "Study of X in Y: a trial" inside a string value is not corrupted
    fixed_str = re.sub(
        r'(^|[{,])(\s*)([A-Za-z_][A-Za-z0-9_]*)(?=\s*:)',
        r'\1\2"\3"',
        fixed_str,
        flags=re.MULTILINE,
    )

    return fixed_str

In [3]:
!pip install json_repair
import chromadb
from datasets import load_dataset
import json
import json_repair
from sentence_transformers import SentenceTransformer
from unidecode import unidecode

# On-disk Chroma store; the collection is fetched if present, otherwise created.
client = chromadb.PersistentClient(path="./clinical_trials_chroma")
collection = client.get_or_create_collection("clinical_trials_studies")

# SciNCL sentence encoder used by process_batch to embed "title [SEP] description" text.
model = SentenceTransformer("malteos/scincl")

# Reloaded here so this cell does not depend on the first cell having been run.
ravis_dataset = load_dataset("ravistech/clinical-trial-llm-cancer-restructure")

def _study_to_record(study):
    """Convert one dataset row into the (text, metadata, document, id) tuple stored in Chroma.

    Returns None when the row cannot be indexed: its metadata does not parse to a
    JSON object, or it has no official title or no description.
    """
    metadata = json_repair.loads(fix_invalid_json(study['metadata']))
    # json_repair returns whatever it could parse (possibly not a dict); only objects carry our fields
    if not isinstance(metadata, dict):
        return None
    official_title = metadata.get('Official_title', '')
    detailed_description = study.get('data', '')

    # Skip if no valid officialTitle or detailedDescription
    if not official_title or not detailed_description:
        return None

    nct_id = str(metadata.get("NCT_ID") or "unknown")
    text = unidecode(f"{official_title} [SEP] {detailed_description}")
    record_metadata = {
        "nctId": nct_id,
        "officialTitle": official_title,
        "detailedDescription": detailed_description,
        "jsonMetadata": json.dumps(metadata, ensure_ascii=True),
    }
    document = json.dumps({
        "metadata": metadata,
        "description": detailed_description,
        "criteria": study.get('criteria', ''),
    }, ensure_ascii=True)
    return text, record_metadata, document, nct_id


def embed_studies_from_dataset(dataset, batch_size=32, split='train'):
    """Embed every usable study of one dataset split and store it in Chroma in batches.

    Args:
        dataset: mapping from split name to rows; each row has 'metadata', 'data'
            and 'criteria' fields (the DatasetDict loaded above).
        batch_size: number of studies sent to ``process_batch`` at a time.
        split: which split of ``dataset`` to index (defaults to 'train').

    Rows whose metadata does not parse to an object, or that lack an official title
    or a description, are skipped. Chroma ids must be unique, so rows without an
    NCT_ID get a positional id ("unknown-<row>") and repeated ids are skipped.
    """
    batch_texts = []
    batch_metadata = []
    batch_documents = []
    batch_ids = []
    seen_ids = set()
    rows = dataset[split]
    length = len(rows)

    # index counts every scanned row (including skipped ones) so it is comparable to length
    for index, study in enumerate(rows, start=1):
        record = _study_to_record(study)
        if record is None:
            continue
        text, record_metadata, document, study_id = record

        if study_id == "unknown":
            # a shared "unknown" id would collide in Chroma; make it unique per row
            study_id = f"unknown-{index}"
        if study_id in seen_ids:
            # Chroma rejects duplicate ids, so keep only the first occurrence
            continue
        seen_ids.add(study_id)

        batch_texts.append(text)
        batch_metadata.append(record_metadata)
        batch_documents.append(document)
        batch_ids.append(study_id)

        # When batch size is reached, process the batch
        if len(batch_texts) == batch_size:
            process_batch(batch_texts, batch_documents, batch_ids, batch_metadata)
            print(f"Processed {len(batch_texts)} studies. {index}/{length}")
            # Clear the batches
            batch_texts.clear()
            batch_documents.clear()
            batch_metadata.clear()
            batch_ids.clear()

    # flush the final, partially filled batch
    if batch_texts:
        process_batch(batch_texts, batch_documents, batch_ids, batch_metadata)
        print(f"Processed {len(batch_texts)} studies. {length}/{length}")

def process_batch(texts, documents, ids, metadatas):
    """Embed one batch of study texts and add it to the Chroma collection.

    Uses the module-level ``model`` (SentenceTransformer) and ``collection``
    (Chroma collection) created at the top of this cell.

    Args:
        texts: strings to embed, one per study.
        documents: strings stored as the Chroma documents, aligned with ``texts``.
        ids: unique Chroma ids, aligned with ``texts``.
        metadatas: per-study metadata dicts, aligned with ``texts``.
    """
    if not texts:
        # nothing to add; encode() would otherwise get batch_size=0 and raise
        return
    # encode the whole batch in one pass; callers size the batch to fit GPU memory
    embeddings = model.encode(texts, batch_size=len(texts))
    collection.add(
        embeddings=embeddings,
        documents=documents,
        metadatas=metadatas,
        ids=ids,
    )
    print(f"Processed and added batch of {len(texts)} studies.")

# Each batch is encoded in a single forward pass, so lower this if the GPU runs out of memory.
EMBED_BATCH_SIZE = 750

embed_studies_from_dataset(ravis_dataset, batch_size=EMBED_BATCH_SIZE)
print("Embedding and storing complete!")





Processed and added batch of 750 studies.
Processed 750 studies. 750/31840
Processed and added batch of 750 studies.
Processed 750 studies. 1500/31840
Processed and added batch of 750 studies.
Processed 750 studies. 2250/31840
Processed and added batch of 750 studies.
Processed 750 studies. 3000/31840
Processed and added batch of 750 studies.
Processed 750 studies. 3750/31840
Processed and added batch of 750 studies.
Processed 750 studies. 4500/31840
Processed and added batch of 750 studies.
Processed 750 studies. 5250/31840
Processed and added batch of 750 studies.
Processed 750 studies. 6000/31840
Processed and added batch of 750 studies.
Processed 750 studies. 6750/31840
Processed and added batch of 750 studies.
Processed 750 studies. 7500/31840
Processed and added batch of 750 studies.
Processed 750 studies. 8250/31840
Processed and added batch of 750 studies.
Processed 750 studies. 9000/31840
Processed and added batch of 750 studies.
Processed 750 studies. 9750/31840
Processed and