In [None]:
import pandas as pd
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import numpy as np
import logging
import os

"""
After the model is fine-tuned, we use it to generate metadata from the name of the column and its context information(table name and dataset name)

"""
# Configure the root logger to emit INFO-level records and create this
# module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory of the fine-tuned T5 checkpoint. The tokenizer and the model are
# both loaded from it at module level, so the functions below use them as
# globals.
model_path = "/Users/raffaelegiancotti/Desktop/projects/clinical_data_transformation/New_model"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)


def generate_column_metadata(column_name, table_name, dataset_name):
    """Generate a natural-language description for a single column.

    A prompt is built from the column name and its context, passed through the
    module-level T5 ``model`` and the first generated sequence is decoded.

    Args:
        column_name: Name of the column to describe.
        table_name: Name of the table the column belongs to.
        dataset_name: Name of the dataset the table belongs to.

    Returns:
        The decoded description string, or ``None`` when tokenization,
        generation or decoding raised an exception (the error is logged).
    """
    try:
        # NOTE(review): this prompt ends with " dataset." whereas the batched
        # prompt in generate_metadata_for_table does not — confirm which
        # wording matches the fine-tuning data.
        prompt = f"Give a description of {column_name} from {table_name} from {dataset_name} dataset."
        encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)

        # Inference only: no gradients needed. The attention mask is forwarded
        # explicitly so generate() never has to infer it from the input ids.
        with torch.no_grad():
            generated = model.generate(
                encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_length=1024,
            )

        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception:
        # Best-effort: report the failure with its traceback and signal it
        # with a single None, matching the single-value success return
        # (previously a truthy ``(None, None)`` tuple was returned).
        logger.exception("Error generating metadata for column '%s'", column_name)
        return None

def generate_metadata_for_table(df, table_name, dataset_name, batch_size=8):
    """Build a metadata table for ``df`` with model-generated column descriptions.

    The first record describes the table itself (names, row and column
    counts). It is followed by one record per column holding the column name,
    a few sampled values, the dtype and a description generated by the
    module-level T5 ``model``. Columns are sent to the model in batches.

    Args:
        df: DataFrame whose columns are to be described.
        table_name: Name of the table, used in the prompt and the table record.
        dataset_name: Name of the dataset, used in the prompt and the table record.
        batch_size: Number of columns sent to the model per generate() call.

    Returns:
        A ``pandas.DataFrame`` built from the list of records described above.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    columns = df.columns.tolist()
    # Never request more sample values than the table has rows; a fixed
    # sample(5) raises ValueError on tables with fewer than 5 rows.
    sample_size = min(5, len(df))

    metadata = [{
        "Table Name": table_name,
        "Dataset Name": dataset_name,
        "Table description": "A table containing...",  # to define
        "Number of Rows": len(df),
        "Number of Columns": len(df.columns),
    }]

    for start in range(0, len(columns), batch_size):
        batch_columns = columns[start:start + batch_size]
        # NOTE(review): unlike generate_column_metadata, this prompt has no
        # trailing " dataset." — confirm which wording matches the
        # fine-tuning data.
        prompts = [
            f"Give a description of {column} from {table_name} from {dataset_name}"
            for column in batch_columns
        ]

        # Prompts in a batch are padded to a common length, so the attention
        # mask must be forwarded; otherwise the encoder attends to pad tokens
        # of the shorter prompts.
        encoded = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            generated = model.generate(
                encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_length=1024,
            )

        for column, sequence in zip(batch_columns, generated):
            values = df[column]
            metadata.append({
                "Column name": column,
                "Sample data": list(values.sample(sample_size)) if sample_size else [],
                "Data type": str(values.dtype),
                "Column description": tokenizer.decode(sequence, skip_special_tokens=True),
            })

    return pd.DataFrame(metadata)

if __name__ == "__main__":
    # Source table, destination of the generated metadata, and dataset label.
    file_path = "/Users/raffaelegiancotti/Desktop/projects/clinical_data_transformation/data/northwestern-icu-nwicu-database-0.1.0/data/nw_hosp/emar.csv"
    generated_path = "/Users/raffaelegiancotti/Desktop/projects/clinical_data_transformation/data/emar_nwicu_metadata.csv"
    dataset_name = "nwicu"

    # The table name is the CSV file name without its extension.
    file_name = os.path.basename(file_path)
    table_name = os.path.splitext(file_name)[0]

    # Load the table, describe it, and persist the result without the index.
    df = pd.read_csv(file_path)
    metadata_df = generate_metadata_for_table(df, table_name, dataset_name)
    metadata_df.to_csv(generated_path, index=False)

    print("Metadata generation complete. Saved to", generated_path)


Metadata generation complete. Saved to /Users/raffaelegiancotti/Desktop/projects/clinical_data_transformation/data/emar_nwicu_metadata.csv
