In [None]:
# Import python packages
import streamlit as st
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from pandas.tseries.offsets import *
from nltk.corpus import stopwords
import nltk
from collections import Counter
import re
from colorama import Fore

import snowflake.snowpark as snowpark

In [None]:
def data_import():
    """Load the source rows from Snowflake into a pandas DataFrame.

    Runs a fixed ``SELECT *`` (capped at 300,000 rows) against
    ``research.kdolgin.unique_descriptions_keep`` through the notebook's
    active Snowpark session.

    Returns:
        pandas.DataFrame: The query result. If running the query or converting
        it to pandas raises, the error is printed and an empty DataFrame is
        returned instead.
    """
    session = snowpark.context.get_active_session()

    # NOTE(review): LIMIT without ORDER BY does not pin which rows are
    # returned, so the sample may differ between runs -- confirm this is OK.
    query = """SELECT * FROM research.kdolgin.unique_descriptions_keep limit 300000"""

    try:
        return session.sql(query).to_pandas()
    except Exception as e:
        # Best effort: report the problem and hand back an empty frame so the
        # caller can decide what to do.
        print(f"Error: {e}")
        return pd.DataFrame()

# Load the raw rows and show a quick preview.
data = data_import()
# len() gives the number of rows; DataFrame.count() would instead return a
# per-column Series of non-null counts.
print(f"Count of records : {len(data)}")
data.head()

In [None]:
# Columns joined (in this order, separated by single spaces) into COMBINED_TEXT.
# Missing values are replaced by '' before joining. Entries that are commented
# out are currently excluded from the combined text.
_combined_text_columns = [
    # "SOURCE_DESCRIPTION",
    "LVL3_TITLE",
    "LVL6_GMDN_TT_NAME",
    "BRAND_NAME_CLEAN",
    "DEVICE_SYNONYM_LIST",
    "MANUFACTURER_SUBMITTED_DESCRIPTION_CLEAN",
    "UNSPSC_TITLE",
    "GMDN_PT_NAME",
    # "COMPANY_NAME",
    "SKU",
    # "UNSPSC_DEFINITION",
    "BY_NAME_GMDN",
    "LONG_DESCRIPTION",
    "PRODUCT_DEFINITION",
    "DEVICE_DESCRIPTION",
    "DEVICE_DESCRIPTION_CLEAN_LONG",
    # "DEVICE_TXT",
    # "DEVICE_CATEGORY_PATH",
    # "DEVICE_ATTRIBUTE_ASSORTMENT_GMDN",
    # "DEVICE_DESCRIPTION_254_NO_FLAG",
    "DEVICE_CLINICAL_DESCRIPTION_300",
]

# Start from the first column, then append " " + next column, left to right.
_combined_text = data[_combined_text_columns[0]].fillna('')
for _column_name in _combined_text_columns[1:]:
    _combined_text = _combined_text + " " + data[_column_name].fillna('')

data["COMBINED_TEXT"] = _combined_text

# Extract the category name that precedes the first '(' in BY_USE_GMDN,
# then trim surrounding whitespace.
data['BY_USE_GMDN_CAT'] = data['BY_USE_GMDN'].astype(str).apply(lambda x: x.split('(')[0])
data['BY_USE_GMDN_CAT'] = data['BY_USE_GMDN_CAT'].astype(str).str.strip()
# Maps each BY_USE_GMDN category name to a coarser high-level category.
# Names that are not keys of this dict map to NaN in HighLevelCategory.
category_map = {
    "Anaesthesia and respiratory devices": "Anaesthesia, Pulmonary, and Respiratory Devices",
    "Pulmonary devices": "Anaesthesia, Pulmonary, and Respiratory Devices",
    "Cardiovascular devices": "Cardiovascular Devices",
    "Gastro-urological devices": "Endoscopic and Gastro-Urological Devices",
    "Endoscopic devices": "Endoscopic and Gastro-Urological Devices",
    "Body fluid and tissue management devices": "General Hospital Devices",
    "Body tissue manipulation and reparation devices": "General Hospital Devices",
    "General hospital devices": "General Hospital Devices",
    "Dermatological and soft-tissue reconstructive/cosmetic devices": "Dermatological, Plastic, and Cosmetic Surgery Devices",
    "Plastic surgery and cosmetic devices": "Dermatological, Plastic, and Cosmetic Surgery Devices",
    "Dental devices": "Dental/Maxillofacial Devices",
    "Dental/Maxillofacial devices": "Dental/Maxillofacial Devices",
    "Complementary therapy devices": "General Hospital Devices",
    "Disability-assistive products": "General Hospital Devices",
    "Physical therapy devices": "General Hospital Devices",
    "Ear/Nose/Throat": "Neurological and ENT Devices",
    "Neurological devices": "Neurological and ENT Devices",
    "In vitro diagnostic medical devices": "Laboratory and Diagnostic Devices",
    "Laboratory instruments and equipment": "Laboratory and Diagnostic Devices",
    "Obstetrical/Gynaecological devices": "Obstetrical and Gynaecological Devices",
    "Ophthalmic devices": "Ophthalmic Devices",
    "Orthopaedic devices": "Orthopaedic Devices",
    "Radiological devices": "Laboratory and Diagnostic Devices",
    "Healthcare facility products and adaptations": "General Hospital Devices"
}
data["HighLevelCategory"] = data["BY_USE_GMDN_CAT"].map(category_map)
data["HighLevelCategory"].unique()
# Take an explicit copy: data_agg has its columns reassigned later, and writing
# into a plain slice of `data` triggers pandas' SettingWithCopyWarning.
data_agg = data[["HighLevelCategory", "COMBINED_TEXT", "SOURCE_DESCRIPTION"]].copy()


data.head()

In [None]:
import torch
import torch.nn.functional as F
from nltk.corpus import stopwords
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder

In [None]:
# Report the installed PyTorch version and the compute device that is available.
_cuda_available = torch.cuda.is_available()
_selected_device = torch.device('cuda' if _cuda_available else 'cpu')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_available}")
print(f"Device: {_selected_device}")


# 1) Setup


In [None]:
# nltk.download('stopwords')
#stop_words = set(stopwords.words('english'))

# Hard-coded English stop-word list (the NLTK-based variant above is disabled).
stop_words = ["i","me","my","myself","we", "our", "ours", "ourselves", "you", "your", "yours", "yourself", "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its", "itself", "they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this", "that", "these", "those", "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "having", "do", "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while", "of", "at", "by", "for", "with", "about", "against", "between", "into", "through", "during", "before", "after", "above", "below", "to", "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", "further", "then", "once", "here", "there", "when", "where", "why", "how", "all", "any", "both", "each", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s", "t", "can", "will", "just", "don", "should", "now"]
# Frozen snapshot of stop_words for O(1) membership tests in custom_clean.
# It is built once here; re-run this cell after editing stop_words.
_STOP_WORD_SET = frozenset(stop_words)


def custom_clean(text):
    """Normalize a free-text description for embedding training.

    Steps: lower-case, drop standalone numbers, drop every character other
    than ``a-z``, ``0-9``, whitespace, ``-`` and ``/``, then remove stop words
    and single-character tokens.

    Args:
        text: The value to clean. Non-string values are converted with
            ``str()``; ``None`` and float NaN are treated as missing.

    Returns:
        str: The remaining tokens joined by single spaces; ``''`` for missing
        input.
    """
    # Missing values would otherwise be stringified into the literal tokens
    # "none" / "nan" and leak into the training text.
    if text is None or (isinstance(text, float) and text != text):
        return ''

    text = str(text).lower()
    text = re.sub(r'\b\d+\b', '', text)  # Remove standalone numbers
    text = re.sub(r'[^a-z0-9\s\-/]', '', text)  # Keep medically relevant chars
    tokens = text.split()
    cleaned_tokens = [t for t in tokens if len(t) > 1 and t not in _STOP_WORD_SET]
    return ' '.join(cleaned_tokens)
# Apply custom text cleaning.
# assign() returns a new DataFrame, so the cleaned columns are never written
# into a slice of `data` (which would raise pandas' SettingWithCopyWarning).
data_agg = data_agg.assign(
    SOURCE_DESCRIPTION=data_agg["SOURCE_DESCRIPTION"].apply(custom_clean),
    COMBINED_TEXT=data_agg["COMBINED_TEXT"].apply(custom_clean),
)

# 1.1 Upload models.zip to the stage (the cell below lists the stage contents to verify)

In [None]:

-- List the files currently present in the model stage.
LIST @RESEARCH.AAGARWAL.model_stage;


# 1.2 Download and extract models.zip from the stage

In [None]:
import zipfile
import shutil
import snowflake.snowpark as snowpark
from snowflake.snowpark.session import Session


# Create a Snowflake session
session = Session.builder.getOrCreate()

# Start from an empty working directory so files from earlier runs never leak in.
model_dir = "extracted_model"
if os.path.exists(model_dir):
    shutil.rmtree(model_dir)  # Remove if exists to start fresh
os.makedirs(model_dir, exist_ok=True)

# Local download target. Snowpark's session.file.get() interprets its second
# argument as a target *directory* (creating it if needed), so after the
# download this path is a directory that contains the real models.zip file.
download_path = f"{model_dir}/models.zip"
zip_name = "models.zip"

try:
    session.file.get("@RESEARCH.AAGARWAL.model_stage/models.zip", download_path)
    print(f"File downloaded to {download_path}")

    if os.path.isdir(download_path):
        # Expected case: the archive sits inside the download directory.
        # Extract next to it so the model ends up under
        # extracted_model/models.zip/..., the location the model-loading cell reads.
        zip_file_path = os.path.join(download_path, zip_name)
        extract_target = download_path
    else:
        # Fallback in case the path turned out to be the archive itself.
        zip_file_path = download_path
        extract_target = model_dir

    if os.path.isfile(zip_file_path):
        print(f"File size: {os.path.getsize(zip_file_path)} bytes")

        # Extract the zip file
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(extract_target)
        print(f"Extraction complete to {extract_target}")

        # List extracted contents
        print("Extracted contents:")
        for root, dirs, files in os.walk(model_dir):
            for file in files:
                if file != "models.zip":  # Skip the zip file itself
                    print(f"  {os.path.join(root, file)}")
    else:
        print(f"ERROR: {zip_file_path} does not exist after download attempt")
except Exception as e:
    print(f"Error during download or extraction: {e}")

    # Let's try an alternative approach
    print("\nTrying alternative approach...")

    # List stage content for debugging
    result = session.sql("LIST @RESEARCH.AAGARWAL.model_stage/").collect()
    print("Files in stage:")
    for row in result:
        print(f"  {row['name']}")

    # Try direct extract if possible
    extract_dir = "models_extract_direct"
    os.makedirs(extract_dir, exist_ok=True)
    try:
        # Try to get individual files if models.zip was already extracted in the stage.
        # FileOperation.get() has no `recursive` keyword; passing one raises
        # TypeError before anything is downloaded.
        session.file.get("@RESEARCH.AAGARWAL.model_stage/", extract_dir)
        print(f"Files downloaded directly to {extract_dir}")
    except Exception as e2:
        print(f"Alternative approach failed: {e2}")

In [None]:
import glob

# First, let's check what's in the stage
result = session.sql("LIST @RESEARCH.AAGARWAL.model_stage/").collect()
print("Files in stage:")
for row in result:
    print(f"  {row['name']}")

# Try to get all files from the stage directly
print("\nDownloading files from stage...")
try:
    session.file.get("@RESEARCH.AAGARWAL.model_stage/", model_dir)
    print(f"Files downloaded to {model_dir}")

    # Check what was downloaded (printed as an indented tree)
    print("\nDownloaded contents:")
    for root, dirs, files in os.walk(model_dir):
        level = root.replace(model_dir, '').count(os.sep)
        indent = ' ' * 4 * level
        print(f"{indent}{os.path.basename(root)}/")
        sub_indent = ' ' * 4 * (level + 1)
        for file in files:
            print(f"{sub_indent}{file}")

    # Look for ZIP files that might need extraction.
    # glob also matches *directories* whose name ends in ".zip" (such as a
    # download directory named models.zip), so keep regular files only;
    # otherwise ZipFile fails on the directory and prints a spurious error.
    zip_files = [
        path
        for path in glob.glob(f"{model_dir}/**/*.zip", recursive=True)
        if os.path.isfile(path)
    ]
    for zip_path in zip_files:
        print(f"\nFound zip file: {zip_path}")
        extract_dir = os.path.dirname(zip_path)
        print(f"Extracting to: {extract_dir}")
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_dir)
            print("Extraction successful")
        except Exception as e:
            print(f"Error extracting {zip_path}: {e}")

    # A directory named models.zip is deliberately left in place: the
    # model-loading cell reads from extracted_model/models.zip/models.

    # Check if we have the expected model files
    model_files = glob.glob(f"{model_dir}/**/pytorch_model.bin", recursive=True)
    if model_files:
        print("\nFound model files:")
        for model_file in model_files:
            print(f"  {model_file}")
    else:
        print("\nNo pytorch_model.bin files found in the extracted content.")

except Exception as e:
    print(f"Error during download: {e}")

# 2) Load Pretrained Model (Local Path)

In [None]:
######### TO-DO : get the model

# Local directory that holds the extracted model files.
model_path = r"extracted_model/models.zip/models"

# Fail early with an actionable message. NOTE(review): for a path that does
# not exist locally, from_pretrained() is expected to treat the string as a
# Hugging Face Hub repo id and fail with a much less obvious error -- confirm.
if not os.path.isdir(model_path):
    raise FileNotFoundError(
        f"Model directory not found: {model_path}. "
        "Run the download/extraction cells first and check the extracted layout."
    )

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 3) Create Dataset (Contrastive Learning)

In [None]:
class MedicalDataset(Dataset):
    """Pairs of texts used to train embeddings with a contrastive objective.

    Each item is a ``(COMBINED_TEXT, SOURCE_DESCRIPTION)`` tuple taken from the
    same row of the DataFrame passed to the constructor.
    """

    def __init__(self, df):
        # Materialize both columns once so item lookup is a plain list index.
        combined_column = df["COMBINED_TEXT"]
        source_column = df["SOURCE_DESCRIPTION"]
        self.text_pairs = [
            (combined, source)
            for combined, source in zip(combined_column, source_column)
        ]

    def __len__(self):
        """Number of text pairs."""
        return len(self.text_pairs)

    def __getitem__(self, idx):
        """Return the ``(combined_text, source_description)`` pair at ``idx``."""
        return self.text_pairs[idx]
# Build the training dataset of (COMBINED_TEXT, SOURCE_DESCRIPTION) pairs from data_agg.
dataset = MedicalDataset(data_agg)

# 4) Collate Function (Tokenization)

In [None]:
def collate_fn(batch):
    """Tokenize a batch of text pairs and move the tensors to ``device``.

    Args:
        batch: Sequence of 2-item pairs; element 0 is the combined text and
            element 1 is the source description.

    Returns:
        tuple[dict, dict]: Tokenizer outputs for the combined texts and for the
        source texts, each a dict of tensors already placed on ``device``.
    """
    def _tokenize(texts):
        # Pad to the longest text in the batch, truncating at 256 tokens.
        return tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt")

    def _to_device(encoded):
        return {name: tensor.to(device) for name, tensor in encoded.items()}

    combined_texts = []
    source_texts = []
    for pair in batch:
        combined_texts.append(pair[0])
        source_texts.append(pair[1])

    combined_encoding = _tokenize(combined_texts)
    source_encoding = _tokenize(source_texts)
    return _to_device(combined_encoding), _to_device(source_encoding)
# Training batches: 32 shuffled text pairs per batch, tokenized by collate_fn.
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,  # reshuffle the pairs every epoch
    collate_fn=collate_fn
)

# 5) Mean Pooling (Convert BERT Output to Sentence Embeddings)

In [None]:
def mean_pooling(last_hidden_state, attention_mask):
    """Average token embeddings over the positions selected by the mask.

    Args:
        last_hidden_state: Token-level model output; the mask is expanded to
            this tensor's size after gaining a trailing dimension.
        attention_mask: Mask with one entry per token; positions with value 0
            contribute nothing to the average.

    Returns:
        The masked sum over dimension 1 divided by the mask count (the count is
        clamped to at least 1e-9 to avoid division by zero).
    """
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size())
    masked_total = (last_hidden_state * mask).sum(dim=1)
    token_count = torch.clamp(mask.sum(dim=1), min=1e-9)
    return masked_total / token_count

# 6) Loss Function: Multiple Negatives Ranking Loss

In [None]:
def multiple_negatives_ranking_loss(embeddings_a, embeddings_b, scale=1.0):
    """In-batch contrastive loss over paired embeddings.

    Row ``i`` of ``embeddings_a`` is scored (cosine similarity) against every
    row of ``embeddings_b``. Row ``i`` of ``embeddings_b`` is its positive
    (the diagonal of the score matrix); all other rows act as negatives.

    Args:
        embeddings_a: 2-D tensor of anchor embeddings, one row per example.
        embeddings_b: 2-D tensor of paired embeddings with the same number of
            rows as ``embeddings_a``.
        scale: Multiplier applied to the cosine similarities before the
            cross-entropy. The default of 1.0 keeps the original behaviour.
            Cosine scores lie in [-1, 1], so larger values sharpen the softmax
            (sentence-transformers' MultipleNegativesRankingLoss uses 20.0).

    Returns:
        Scalar tensor: mean cross-entropy over the rows of the score matrix.
    """
    embeddings_a = F.normalize(embeddings_a, p=2, dim=1)
    embeddings_b = F.normalize(embeddings_b, p=2, dim=1)
    scores = torch.matmul(embeddings_a, embeddings_b.T) * scale
    # The positive for row i is column i.
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)

# 7) Training Loop (Contrastive Learning)

In [None]:
# Fine-tune the encoder so that COMBINED_TEXT and SOURCE_DESCRIPTION of the
# same row receive similar embeddings.
optimizer = AdamW(model.parameters(), lr=2e-5)
epochs = 3
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    num_batches = 0

    for enc_combined, enc_source in train_loader:
        optimizer.zero_grad()

        # Forward pass for COMBINED_TEXT
        out_c = model(**enc_combined)
        embeddings_c = mean_pooling(out_c.last_hidden_state, enc_combined["attention_mask"])
        # Forward pass for SOURCE_DESCRIPTION
        out_s = model(**enc_source)
        embeddings_s = mean_pooling(out_s.last_hidden_state, enc_source["attention_mask"])
        # Compute contrastive loss
        loss = multiple_negatives_ranking_loss(embeddings_c, embeddings_s)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # The summed loss grows with the number of batches, so also report the
    # per-batch mean, which is comparable across dataset sizes. max(..., 1)
    # guards against an empty loader.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{epochs}, Loss={total_loss:.4f}, AvgLoss={avg_loss:.4f}")