## Импорт зависимостей

In [1]:
!pip install medcam3d grad-cam -q

In [2]:
import os
import re
import torch
import torchvision
import pandas as pd
from torch import nn, optim
from transformers import GPT2Tokenizer, GPT2LMHeadModel, BertTokenizer, BertForSequenceClassification
from pytorch_grad_cam import GradCAM
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report
from sklearn.model_selection import GroupShuffleSplit

## Загрузка и предобработка данных

In [3]:
rpt = pd.read_csv('/kaggle/input/chest-xrays-indiana-university/indiana_reports.csv')
proj = pd.read_csv('/kaggle/input/chest-xrays-indiana-university/indiana_projections.csv')

In [4]:
df = proj.merge(rpt, on='uid', how='left')
df = df.dropna(subset=['findings','impression'])
df['report'] = df['findings'].fillna('') + ' ' + df['impression'].fillna('')

In [5]:
df["path"] = '/kaggle/input/chest-xrays-indiana-university/images/images_normalized/' + df["filename"]

In [6]:
df.head()

Unnamed: 0,uid,filename,projection,MeSH,Problems,image,indication,comparison,findings,impression,report,path
0,1,1_IM-0001-4001.dcm.png,Frontal,normal,normal,Xray Chest PA and Lateral,Positive TB test,None.,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,The cardiac silhouette and mediastinum size ar...,/kaggle/input/chest-xrays-indiana-university/i...
1,1,1_IM-0001-3001.dcm.png,Lateral,normal,normal,Xray Chest PA and Lateral,Positive TB test,None.,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,The cardiac silhouette and mediastinum size ar...,/kaggle/input/chest-xrays-indiana-university/i...
2,2,2_IM-0652-1001.dcm.png,Frontal,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,None.,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Borderline cardiomegaly. Midline sternotomy XX...,/kaggle/input/chest-xrays-indiana-university/i...
3,2,2_IM-0652-2001.dcm.png,Lateral,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,None.,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Borderline cardiomegaly. Midline sternotomy XX...,/kaggle/input/chest-xrays-indiana-university/i...
6,4,4_IM-2050-1001.dcm.png,Frontal,"Pulmonary Disease, Chronic Obstructive;Bullous...","Pulmonary Disease, Chronic Obstructive;Bullous...","PA and lateral views of the chest XXXX, XXXX a...",XXXX-year-old XXXX with XXXX.,None available,There are diffuse bilateral interstitial and a...,1. Bullous emphysema and interstitial fibrosis...,There are diffuse bilateral interstitial and a...,/kaggle/input/chest-xrays-indiana-university/i...


Разбиение на выборки

In [7]:
split = GroupShuffleSplit(test_size=0.2, n_splits=1, random_state=42)
train_idx, test_idx = next(split.split(df, groups=df['uid']))
train_df, test_df = df.iloc[train_idx], df.iloc[test_idx]

In [8]:
train_df.shape, test_df.shape

((5164, 12), (1293, 12))

### Парсинг MeSH → pathology теги

In [9]:
common_labels = [
    "atelectasis","cardiomegaly","consolidation","edema","pleural effusion",
    "pneumonia","pneumothorax","support devices","enlarged cardiomediastinum",
    "fracture","lung lesion","lung opacity","pleural other","pleural thickening", "normal"
]
label_map = {lbl: i for i,lbl in enumerate(common_labels)}

Построение label_map

In [10]:
def extract_labels(mesh_str):
    tags = []
    for term in re.split('[;/]', mesh_str.lower()):
        term = term.strip()
        for lbl in common_labels:
            if lbl in term:
                tags.append(lbl)
    return list(set(tags))


In [11]:
# df_labels = df[['uid','MeSH']].dropna()
df_labels = df[['uid','MeSH']].drop_duplicates().copy()

df_labels['labels'] = df_labels['MeSH'].apply(extract_labels)

mlb = MultiLabelBinarizer(classes=common_labels)
mlb.fit([common_labels])

In [12]:
label_map = {}
for uid, grp in df_labels.groupby('uid'):
    lbls = grp['labels'].sum()  # объединяем списки (хотя обычно один элемент)
    vect = mlb.transform([lbls])[0]
    label_map[uid] = torch.tensor(vect, dtype=torch.float)

In [13]:
# Проверяем, что есть метки для каждой группы:
miss = set(train_df['uid']) - set(label_map.keys())
print("UIDs without labels:", miss)
assert not miss, "Проверьте функцию extract_labels: не все UID получили метки"

UIDs without labels: set()


### DataLoader с фронтальными и латеральными проекциями

In [14]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

In [15]:
class IUDataset(Dataset):
    def __init__(self, df, img_dir, label_map, transform=None):
        self.transform = transform
        self.groups = df.groupby('uid')
        self.uids = list(self.groups.groups.keys())
        self.img_dir = img_dir
        self.label_map = label_map

    def __len__(self):
        return len(self.uids)

    def __getitem__(self, idx):
        uid = self.uids[idx]
        group = self.groups.get_group(uid)
        # выбираем frontal
        if 'projection' in group.columns and 'frontal' in group['projection'].values:
            row = group[group['projection']=='frontal'].iloc[0]
        else:
            row = group.iloc[0]

        path = row['path']  # убедитесь, что колонка называется именно так
        image = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
        if self.transform:
            image = self.transform(image)

        labels = self.label_map[uid]        # теперь это dict, а не Compose
        report = row['report']

        return {'img_f': image, 'report': report, 'labels': labels}


In [16]:
tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

In [17]:
train_ds = IUDataset(train_df, '/kaggle/input/chest-xrays-indiana-university/images/images_normalized', transform=tfm, label_map=label_map)
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2)

test_ds = IUDataset(test_df, '/kaggle/input/chest-xrays-indiana-university/images/images_normalized', transform=tfm, label_map=label_map)
test_dl = DataLoader(test_ds, batch_size=16, shuffle=True, num_workers=2)

## Определение DenseTagger (CheXNet‑backbone)

In [18]:
class DenseTagger(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.backbone = torchvision.models.densenet121(pretrained=True)
        # Заменяем классификатор на multi-label
        in_features = self.backbone.classifier.in_features
        self.backbone.classifier = nn.Linear(in_features, num_labels)

    def forward(self, x):
        # x: [batch, 3, H, W]
        logits = self.backbone(x)
        return logits  # BCEWithLogitsLoss применит sigmoid внутри

In [19]:
# Гиперпараметры
num_labels = 15  # Число патологий в IU X-Ray (настройте по CSV)
batch_size = 16
lr = 1e-4
num_epochs = 5

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [21]:
# Инициализация модели, loss и оптимизатора
model = DenseTagger(num_labels=num_labels).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)



In [22]:
def train_epoch(model, loader, criterion, optimizer):
    """
    Обучение модели
    """
    model.train()
    total_loss = 0.0
    for batch in loader:
        imgs = batch['img_f'].to(device)  # используем frontal view
        # Если нужны оба вида: concat или два потока
        targets = batch['labels'].to(device)  # tensor [B, num_labels]

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
    return total_loss / len(loader.dataset)

In [23]:
def validate_epoch(model, loader, criterion):
    """
    Валидация модели
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            imgs = batch['img_f'].to(device)
            targets = batch['labels'].to(device)
            logits = model(imgs)
            loss = criterion(logits, targets)
            total_loss += loss.item() * imgs.size(0)
    return total_loss / len(loader.dataset)

## Обучение с сохранением наилучшего чекпоинта по минимальной валидационной потере

In [24]:
best_val = float('inf')
best_epoch = -1

In [25]:
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_dl, criterion, optimizer)
    val_loss = validate_epoch(model, test_dl, criterion)

    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    if val_loss < best_val:
        best_val = val_loss
        best_epoch = epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss
        }, 'best_dense_tagger.pt')
        print(f"=> New best model saved (epoch {epoch}, val_loss={val_loss:.4f})")

print(f"Training complete. Best val_loss {best_val:.4f} at epoch {best_epoch}")

Epoch 1: Train Loss = 0.2015, Val Loss = 0.1109
=> New best model saved (epoch 1, val_loss=0.1109)
Epoch 2: Train Loss = 0.0872, Val Loss = 0.1043
=> New best model saved (epoch 2, val_loss=0.1043)
Epoch 3: Train Loss = 0.0670, Val Loss = 0.1079
Epoch 4: Train Loss = 0.0454, Val Loss = 0.1125
Epoch 5: Train Loss = 0.0310, Val Loss = 0.1247
Training complete. Best val_loss 0.1043 at epoch 2


Загрузка наилучшей модели

In [26]:
ckpt = torch.load('best_dense_tagger.pt')
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
print(f"Loaded best model from epoch {ckpt['epoch']} with val_loss = {ckpt['val_loss']:.4f}")

Loaded best model from epoch 2 with val_loss = 0.1043


In [28]:
model.eval()
with torch.no_grad():
    batch = next(iter(test_dl))
    imgs = batch['img_f'].to(device)
    logits = model(imgs)               # [B,14]
    probs = torch.sigmoid(logits)      # вероятности
    preds = (probs > 0.5).int()        # бинарные метки

print("Probs:\n", probs.cpu().numpy())
print("Preds:\n", preds.cpu().numpy())
print("GT:\n", batch['labels'].cpu().numpy())


Probs:
 [[0.0489841  0.1184022  0.04343482 0.01905177 0.07111479 0.01833937
  0.01269776 0.00792821 0.00661284 0.0226193  0.01227041 0.01339292
  0.0135476  0.00783083 0.06678728]
 [0.6408981  0.72030294 0.02763935 0.05626715 0.15666099 0.02187985
  0.01857601 0.00527147 0.00690743 0.0251747  0.01274019 0.01302264
  0.00949994 0.00827656 0.01815701]
 [0.02212089 0.04957295 0.00935958 0.01136032 0.00628911 0.00567096
  0.00964182 0.00432027 0.0069657  0.00807972 0.00900637 0.01040506
  0.00812291 0.00747889 0.43422875]
 [0.02780366 0.01074267 0.0093627  0.00740412 0.00683731 0.01194001
  0.00813514 0.00614844 0.00449053 0.011826   0.00681223 0.00953462
  0.0057378  0.00492156 0.51197207]
 [0.00847246 0.00367822 0.01074441 0.00470991 0.00822416 0.00758066
  0.00620317 0.00593116 0.00568199 0.00911417 0.00626851 0.00639581
  0.00768462 0.00452718 0.37510377]
 [0.02738944 0.00602952 0.00750633 0.00441963 0.00481591 0.00794734
  0.00646436 0.00424769 0.00407445 0.0083392  0.00533111 0.00809

In [29]:
model.eval()
with torch.no_grad():
    batch = next(iter(test_dl))
    imgs = batch['img_f'].to(device)
    probs = torch.sigmoid(model(imgs)).cpu().numpy()
    preds = (probs > 0.5).astype(int)

for i in range(len(imgs)):
    print(f"=== Sample {i} ===")
    print("Probabilities:", {common_labels[j]: round(float(probs[i,j]), 3) for j in range(len(common_labels))})
    print("Predicted:", [common_labels[j] for j in range(len(common_labels)) if preds[i,j] == 1])
    print("Ground‑Truth:", [common_labels[j] for j in range(len(common_labels)) if batch['labels'][i,j] == 1])
    print()

=== Sample 0 ===
Probabilities: {'atelectasis': 0.057, 'cardiomegaly': 0.048, 'consolidation': 0.026, 'edema': 0.02, 'pleural effusion': 0.104, 'pneumonia': 0.023, 'pneumothorax': 0.016, 'support devices': 0.009, 'enlarged cardiomediastinum': 0.007, 'fracture': 0.021, 'lung lesion': 0.012, 'lung opacity': 0.012, 'pleural other': 0.011, 'pleural thickening': 0.009, 'normal': 0.055}
Predicted: []
Ground‑Truth: ['cardiomegaly']

=== Sample 1 ===
Probabilities: {'atelectasis': 0.045, 'cardiomegaly': 0.167, 'consolidation': 0.007, 'edema': 0.018, 'pleural effusion': 0.013, 'pneumonia': 0.009, 'pneumothorax': 0.006, 'support devices': 0.003, 'enlarged cardiomediastinum': 0.005, 'fracture': 0.01, 'lung lesion': 0.007, 'lung opacity': 0.008, 'pleural other': 0.006, 'pleural thickening': 0.006, 'normal': 0.093}
Predicted: []
Ground‑Truth: []

=== Sample 2 ===
Probabilities: {'atelectasis': 0.058, 'cardiomegaly': 0.029, 'consolidation': 0.026, 'edema': 0.009, 'pleural effusion': 0.035, 'pneumoni

In [31]:
all_preds = []
all_targets = []

model.eval()
with torch.no_grad():
    for batch in test_dl:
        probs = torch.sigmoid(model(batch['img_f'].to(device))).cpu().numpy()
        preds = (probs > 0.5).astype(int)
        all_preds.extend(preds)
        all_targets.extend(batch['labels'].cpu().numpy())

print(classification_report(all_targets, all_preds, target_names=common_labels))

                            precision    recall  f1-score   support

               atelectasis       0.15      0.04      0.06        55
              cardiomegaly       0.33      0.29      0.31        55
             consolidation       0.00      0.00      0.00         6
                     edema       0.00      0.00      0.00         6
          pleural effusion       0.50      0.08      0.14        24
                 pneumonia       0.00      0.00      0.00         8
              pneumothorax       0.00      0.00      0.00         3
           support devices       0.00      0.00      0.00         0
enlarged cardiomediastinum       0.00      0.00      0.00         0
                  fracture       0.00      0.00      0.00        20
               lung lesion       0.00      0.00      0.00         0
              lung opacity       0.00      0.00      0.00         0
             pleural other       0.00      0.00      0.00         0
        pleural thickening       0.00      0.00

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
