In [2]:
import argparse
import ast
import datetime
import gc
import glob
import json
import logging
import math
import os
import random
import re
import shutil
import time
from collections import Counter, defaultdict
from dataclasses import asdict, dataclass, field
from difflib import get_close_matches
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from datasets import DatasetDict, concatenate_datasets, load_from_disk
import datasets
import numpy as np
import torch
import transformers
import yaml

  from .autonotebook import tqdm as notebook_tqdm


In [6]:

def load_image_datasets(data_paths):
    dataset_interpret = load_from_disk(data_paths["interpret"])
    dataset_mimic = load_from_disk(data_paths["mimic"])

    # Concat both
    dataset_train_dev = DatasetDict({"train": concatenate_datasets([dataset_interpret["train"], dataset_mimic["train"]]), "validation": concatenate_datasets([dataset_interpret["validation"], dataset_mimic["validation"]])})

    dataset_test = load_from_disk(data_paths["interpret-test-public"])

    ds_img = DatasetDict({"train": dataset_train_dev["train"], "validation": dataset_train_dev["validation"], "test": dataset_test["test"]})
    return ds_img

def merge_dataset(img_dataset, graph_dataset):
    imgId_2_graphRowIdx = {}
    for graph_row_idx, doc_key in enumerate(graph_dataset["doc_key"]):
        _, img_id, _ = doc_key.split("#")  # doc_key = test#2250#findings
        imgId_2_graphRowIdx[int(img_id)] = int(graph_row_idx)

    # 如果传入的是 select 后的 img_ds 数据集，那么 img_id 与 img_row_idx 不一定是一一对应的
    # data_key: test#89
    imgId_2_imgRowIdx = {}
    for img_row_idx, img_data_key in enumerate(img_dataset["data_key"]):
        _, img_id = img_data_key.split("#")  # data_key = test#89
        imgId_2_imgRowIdx[int(img_id)] = int(img_row_idx)

    # 以数量较少的数据集为基准
    img_ids_in_img_ds = set(imgId_2_imgRowIdx.keys())
    img_ids_in_graph_ds = set(imgId_2_graphRowIdx.keys())
    intersection_ids = img_ids_in_img_ds.intersection(img_ids_in_graph_ds)

    # 按照 img_id 的顺序，将 img_ds 的数据拼接到 graph_ds 的数据中
    filtered_img_ds = img_dataset.select([imgId_2_imgRowIdx[img_id] for img_id in intersection_ids])
    filtered_graph_ds = graph_dataset.select([imgId_2_graphRowIdx[img_id] for img_id in intersection_ids])
    merged_ds = concatenate_datasets([filtered_img_ds, filtered_graph_ds], axis=1)
    return merged_ds

In [4]:
CONFIG = {'output_name': '7_disease_features_pretrain_42obs_without_text_111_10-4', 'output_dir': {'result': '/scratch/c.c21051562/workspace/arrg_img2text/outputs/results/7_disease_features_pretrain_42obs_without_text_111_10-4', 'model': '/scratch/c.c21051562/workspace/arrg_img2text/outputs/models/7_disease_features_pretrain_42obs_without_text_111_10-4', 'checkpoint': '/scratch/c.c21051562/workspace/arrg_img2text/outputs/checkpoints/7_disease_features_pretrain_42obs_without_text_111_10-4'}, 'data_path': {'mimic': '/scratch/c.c21051562/resources/data/mimic-cxr', 'interpret': '/scratch/c.c21051562/resources/data/interpret-cxr', 'interpret-test-public': '/scratch/c.c21051562/resources/data/interpret-cxr-test-public', 'text_graph': '/scratch/c.c21051562/resources/data/interpret_disease'}, 'target_section': 'findings', 'target_observation': ['effusion', 'pneumothorax', 'opacity', 'normal'], 'model_name_or_path': {'clip': '/scratch/c.c21051562/resources/downloaded_models/clip-vit-base-patch32', 'swinv2': '/scratch/c.c21051562/resources/downloaded_models/swinv2-base-patch4-window8-256', 'rad_dino_maira2': '/scratch/c.c21051562/resources/downloaded_models/rad-dino-maira-2', 'llama32_1b': '/scratch/c.c21051562/resources/downloaded_models/Llama-3.2-1B'}, 'mlflow_url': 'http://localhost:6006', 'mlflow_port': '6006', 'max_checkpoints_to_keep': 1, 'resume_from_checkpoint': False, 'use_debug_subset': False, 'run_mode': 'pretrain', 'preprocess': {'image_processor': 'rad_dino_maira2', 'text_processor': 'llama32_1b', 'cache_path': '/scratch/c.c21051562/workspace/arrg_img2text/dataset_cache/interpretcxr_full_text_img518', 'batched': True, 'batch_size': 64, 'num_proc': 16}, 'model': {'vision_model': 'rad_dino_maira2', 'language_model': 'llama32_1b', 'chat_template': '/scratch/c.c21051562/workspace/arrg_img2text/llama3_chat_template7.jinja'}, 'pretrain': {'classification_only': True, 'seed': 42, 'num_epochs': 1, 'batch_size': 1, 'grad_accum_steps': 1, 'warmup_proportion': 0.1, 'lr': 0.0001, 'clip_grad_norm': 1.0, 'mixed_precision': 'bf16', 'print_loss_per_n_steps': 200, 'ckp_per_steps': 10000, 'eval_batch_size': 1, 'max_new_tokens': 512, 'print_pred_per_n_steps': 500, 'eval_valid_split': False, 'num_beams': 3}, 'finetune': {'use_pretrained': False, 'pretain_model_path': '/scratch/c.c21051562/workspace/arrg_img2text/outputs/models/4_1_fsdo_peft_test_pretrain', 'seed': 42, 'num_epochs': 1, 'batch_size': 1, 'grad_accum_steps': 1, 'warmup_proportion': 0.1, 'lr': 0.0001, 'clip_grad_norm': 1.0, 'mixed_precision': 'bf16', 'print_loss_per_n_steps': 200, 'ckp_per_steps': 10000, 'eval_batch_size': 1, 'max_new_tokens': 512, 'print_pred_per_n_steps': 500, 'eval_valid_split': False, 'num_beams': 3}, 'obs_classification_map': ['', 'mentioned', 'absent'], 'observation_map': ['effusion', 'pneumothorax', 'opacity', 'normal', 'consolidation', 'edema', 'atelectasis', 'tube', 'clear', 'catheter', 'pneumonia', 'infiltrate', 'pathophysiologic finding', 'infection', 'congestion', 'enlargement', 'wire', 'degeneration', 'fracture', 'thickening', 'pacemaker', 'emphysema', 'surgical drain', 'surgical clip', 'medical device', 'scoliosis', 'valve', 'chronic obstructive pulmonary disease', 'calcification', 'cirrhosis-associated nodules', 'atherosclerosis', 'calcifications', 'deformity', 'hernia', 'scar', 'pulmonary nodule', 'granuloma', 'automated implantable cardiac defibrillator', 'prosthesis', 'collapse', 'reticular pattern', 'heart failure'], 'jobid': 8081282, 'classification_only': True}

In [7]:
ds_img_path=CONFIG["preprocess"]["cache_path"]
ds_graph_path=CONFIG["data_path"]["text_graph"]
target_section=CONFIG["target_section"]


ds_img = load_image_datasets(data_paths=CONFIG["data_path"])
for data_split in ["train", "validation", "test"]:
    img_dataset = ds_img[data_split]
    img_dataset = img_dataset.add_column("data_key", [f"{data_split}#{idx}" for idx in range(len(img_dataset))])
    ds_img[data_split] = img_dataset

# ds_img = load_from_disk(ds_img_path)

ds_graph_path = os.path.join(ds_graph_path, f"interpret_disease_{target_section}")
ds_graph = load_from_disk(ds_graph_path)

ds_dict = {}
for split in ["train", "validation", "test"]:
    ds_dict[split] = merge_dataset(img_dataset=ds_img[split], graph_dataset=ds_graph[split])

ds_final = DatasetDict(ds_dict)

In [8]:
ds_final

DatasetDict({
    train: Dataset({
        features: ['source', 'images_path', 'images', 'impression', 'findings', 'data_key', 'doc_key', 'split_sents', 'radlex_types', 'radlex_to_splitsents_map'],
        num_rows: 344394
    })
    validation: Dataset({
        features: ['source', 'images_path', 'images', 'impression', 'findings', 'data_key', 'doc_key', 'split_sents', 'radlex_types', 'radlex_to_splitsents_map'],
        num_rows: 8839
    })
    test: Dataset({
        features: ['images', 'findings', 'impression', 'data_key', 'doc_key', 'split_sents', 'radlex_types', 'radlex_to_splitsents_map'],
        num_rows: 2692
    })
})

In [15]:
for idx, item in enumerate(ds_img["train"]):
    if len(item["images"]) > 2:
        print(idx)

26
40
68
69
130
283
349


KeyboardInterrupt: 