In [1]:
import os

# Move from the notebook's folder up to the repository root so that the
# relative paths used below ('./experiments/...', './files/...') and the
# `src.*` imports resolve.
# The guard makes this cell safe to re-run: an unconditional chdir('../../')
# climbs two more levels on every execution and silently breaks every path.
# NOTE(review): assumes `src/` exists only at the repository root, not next to
# this notebook — confirm against the repo layout.
if not os.path.isdir('src'):
    os.chdir('../../')

# Pure-Python replacement for the `!pwd` shell escape (works without a POSIX shell).
print(os.getcwd())

/root/python/myenv/medical-coding-reproducibility-main


In [2]:
# Experiment directory of the selected run for each model, keyed by model name.
# Paths are relative to the working directory set in the first cell.
best_runs = {
    "CAML": "./experiments/jkz0ogtz",
    "PLMICD": "./experiments/jfywyfod",
    "LAAT": "./experiments/gsycv1tg",
    "MultiResCNN": "./experiments/t1dbhfub",
    "VanillaRNN": "./experiments/mjukdu9c",
    "VanillaConv": "./experiments/vy559dlj",
}

In [None]:
import pandas as pd
import torch
import src.metrics as metrics
from src.utils.fun_retrieval import prf
from src.utils.seed import set_seed
import csv
# Evaluate every run in `best_runs` on its test set, first with the raw model
# logits (iteration 0) and then after each of several successive `prf`
# re-scoring rounds, and append the metrics (in percent) to a shared CSV file.
set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Weights and neighbourhood sizes forwarded verbatim to `prf`. They are the
# same for every model, so they are defined once instead of once per model.
w_alpha = 1
w_beta = 0.1
w_gramma = 0.1  # spelling matches the keyword argument name used when calling `prf`
AvgTopRs = [5, 10, 20, 40]
AvgLowR = 10
n_feedback_rounds = 10  # iterations 1..10 come after the baseline iteration 0

# Output file shared by all models and settings; rows are appended.
# NOTE(review): re-running this cell appends duplicate rows — delete the file
# first if a clean table is wanted.
filename = './files/retrieval/retrieval.csv'
# Create the target folder up front so a missing directory fails (or is fixed)
# before any expensive evaluation instead of after the first one.
os.makedirs(os.path.dirname(filename), exist_ok=True)

for model, run in best_runs.items():

    # Ground-truth label matrix of the test split.
    targets = pd.read_feather(run + '/test_targets.feather')
    print('load target finish')
    targets = torch.tensor(targets.values).to(device)

    # Matrix handed to `prf` as its first argument.
    # NOTE(review): presumably the retrieval database with one column per class — confirm.
    retrieve = pd.read_feather(run + '/retrieve.feather')
    retrieve = torch.tensor(retrieve.values).to(device)
    print('load retrieve finish')

    # Only the 'db' entry of the checkpoint is needed; it is used as the
    # decision threshold of the thresholded metrics below. Drop the rest of
    # the checkpoint immediately to free memory.
    loaded_tensor = torch.load(run + '/best_model.pt', map_location='cpu', weights_only=True)
    db = loaded_tensor['db']
    del loaded_tensor
    print('load db finish')

    number_of_classes = retrieve.shape[1]
    metric_collection = metrics.MetricCollection(
        [
            metrics.AUC(number_of_classes=number_of_classes, average="micro"),
            metrics.AUC(number_of_classes=number_of_classes, average="macro"),
            metrics.F1Score(
                number_of_classes=number_of_classes, average="micro", threshold=db
            ),
            metrics.F1Score(
                number_of_classes=number_of_classes, average="macro", threshold=db
            ),
            metrics.ExactMatchRatio(number_of_classes=number_of_classes, threshold=db),
            metrics.Precision_K(k=8, number_of_classes=number_of_classes),
            metrics.Precision_K(k=15, number_of_classes=number_of_classes),
            metrics.MeanAveragePrecision(),
            metrics.PrecisionAtRecall(),
            metrics.Precision(
                number_of_classes=number_of_classes, average="micro", threshold=db
            ),
            metrics.Recall(
                number_of_classes=number_of_classes, average="micro", threshold=db
            ),
            # NOTE(review): this metric takes its class count from `targets`
            # while the others use `retrieve` — the two are presumably equal; confirm.
            metrics.FPR(number_of_classes=targets.shape[1], threshold=db),
        ]
    )

    for AvgTopR in AvgTopRs:
        # Reload the raw predictions for every setting because `logits` is
        # overwritten by the feedback rounds below. The last two columns are
        # dropped.
        # NOTE(review): presumably those are non-class bookkeeping columns — confirm.
        logits = pd.read_feather(run + '/predictions_test.feather').iloc[:, :-2]
        logits = torch.tensor(logits.values).to(device)
        print('load logits finish')

        run_tag = f'{model}_{AvgTopR}_{w_beta}'
        results = []

        # Iteration 0 scores the unmodified logits; each later iteration
        # feeds the previous output of `prf` back into `prf`.
        for i in range(n_feedback_rounds + 1):
            if i > 0:
                logits = prf(retrieve, logits, AvgTopR=AvgTopR, AvgLowR=AvgLowR,
                             w_alpha=w_alpha, w_beta=w_beta, w_gramma=w_gramma,
                             chunk_size_b=80000)

            batch = {"logits": logits, "targets": targets}
            metric_collection.update(batch)
            result = {'model_NavgTop_w_beta': run_tag, 'iteration': i, 'model': model, 'NavgTop': AvgTopR, 'w_beta': w_beta}
            # Metrics are stored as percentages rounded to one decimal place.
            result.update({key: round(value.item() * 100, 1) for key, value in metric_collection.compute(logits, targets).items()})
            results.append(result)
            # Clear accumulated state so the next iteration is scored on its own.
            metric_collection.reset()

        # Append to the CSV if it already exists, otherwise create it.
        file_exists = os.path.exists(filename)
        mode = 'a' if file_exists else 'w'
        # An existing but empty file still needs the header row.
        needs_header = not file_exists or os.path.getsize(filename) == 0

        with open(filename, mode=mode, newline="") as file:
            writer = csv.DictWriter(file, fieldnames=results[0].keys())
            if needs_header:
                writer.writeheader()
            writer.writerows(results)

        print(f"Data {run_tag} has been {'appended to' if file_exists else 'written to'} {filename}.")