## Импорт зависимостей

In [1]:
!pip install medcam3d grad-cam -q

In [22]:
import os
import re
import torch
import torchvision
import pandas as pd
from torch import nn, optim
from transformers import GPT2Tokenizer, GPT2LMHeadModel, BertTokenizer, BertForSequenceClassification
from pytorch_grad_cam import GradCAM
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report
from sklearn.model_selection import GroupShuffleSplit

## Загрузка и предобработка данных

In [23]:
rpt = pd.read_csv('/kaggle/input/chest-xrays-indiana-university/indiana_reports.csv')
proj = pd.read_csv('/kaggle/input/chest-xrays-indiana-university/indiana_projections.csv')

In [24]:
df = proj.merge(rpt, on='uid', how='left')
df = df.dropna(subset=['findings','impression'])
df['report'] = df['findings'].fillna('') + ' ' + df['impression'].fillna('')

In [25]:
df["path"] = '/kaggle/input/chest-xrays-indiana-university/images/images_normalized/' + df["filename"]

In [26]:
df.head()

Unnamed: 0,uid,filename,projection,MeSH,Problems,image,indication,comparison,findings,impression,report,path
0,1,1_IM-0001-4001.dcm.png,Frontal,normal,normal,Xray Chest PA and Lateral,Positive TB test,None.,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,The cardiac silhouette and mediastinum size ar...,/kaggle/input/chest-xrays-indiana-university/i...
1,1,1_IM-0001-3001.dcm.png,Lateral,normal,normal,Xray Chest PA and Lateral,Positive TB test,None.,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,The cardiac silhouette and mediastinum size ar...,/kaggle/input/chest-xrays-indiana-university/i...
2,2,2_IM-0652-1001.dcm.png,Frontal,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,None.,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Borderline cardiomegaly. Midline sternotomy XX...,/kaggle/input/chest-xrays-indiana-university/i...
3,2,2_IM-0652-2001.dcm.png,Lateral,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,None.,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Borderline cardiomegaly. Midline sternotomy XX...,/kaggle/input/chest-xrays-indiana-university/i...
6,4,4_IM-2050-1001.dcm.png,Frontal,"Pulmonary Disease, Chronic Obstructive;Bullous...","Pulmonary Disease, Chronic Obstructive;Bullous...","PA and lateral views of the chest XXXX, XXXX a...",XXXX-year-old XXXX with XXXX.,None available,There are diffuse bilateral interstitial and a...,1. Bullous emphysema and interstitial fibrosis...,There are diffuse bilateral interstitial and a...,/kaggle/input/chest-xrays-indiana-university/i...


Разбиение на выборки

In [27]:
split = GroupShuffleSplit(test_size=0.2, n_splits=1, random_state=42)
train_idx, test_idx = next(split.split(df, groups=df['uid']))
train_df, test_df = df.iloc[train_idx], df.iloc[test_idx]

In [28]:
train_df.shape, test_df.shape

((5164, 12), (1293, 12))

### Парсинг MeSH → pathology теги

In [29]:
common_labels = [
    "atelectasis","cardiomegaly","consolidation","edema","pleural effusion",
    "pneumonia","pneumothorax","support devices","enlarged cardiomediastinum",
    "fracture","lung lesion","lung opacity","pleural other","pleural thickening", "normal"
]
label_map = {lbl: i for i,lbl in enumerate(common_labels)}

Построение label_map

In [30]:
def extract_labels(mesh_str):
    tags = []
    for term in re.split('[;/]', mesh_str.lower()):
        term = term.strip()
        for lbl in common_labels:
            if lbl in term:
                tags.append(lbl)
    return list(set(tags))


In [31]:
# df_labels = df[['uid','MeSH']].dropna()
df_labels = df[['uid','MeSH']].drop_duplicates().copy()

df_labels['labels'] = df_labels['MeSH'].apply(extract_labels)

mlb = MultiLabelBinarizer(classes=common_labels)
mlb.fit([common_labels])

In [32]:
label_map = {}
for uid, grp in df_labels.groupby('uid'):
    lbls = grp['labels'].sum()  # объединяем списки (хотя обычно один элемент)
    vect = mlb.transform([lbls])[0]
    label_map[uid] = torch.tensor(vect, dtype=torch.float)

In [33]:
# Проверяем, что есть метки для каждой группы:
miss = set(train_df['uid']) - set(label_map.keys())
print("UIDs without labels:", miss)
assert not miss, "Проверьте функцию extract_labels: не все UID получили метки"

UIDs without labels: set()


### DataLoader с фронтальными и латеральными проекциями

In [34]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

In [35]:
class IUDataset(Dataset):
    def __init__(self, df, img_dir, label_map, transform=None):
        self.transform = transform
        self.groups = df.groupby('uid')
        self.uids = list(self.groups.groups.keys())
        self.img_dir = img_dir
        self.label_map = label_map

    def __len__(self):
        return len(self.uids)

    def __getitem__(self, idx):
        uid = self.uids[idx]
        group = self.groups.get_group(uid)
        # выбираем frontal
        if 'projection' in group.columns and 'frontal' in group['projection'].values:
            row = group[group['projection']=='frontal'].iloc[0]
        else:
            row = group.iloc[0]

        path = row['path']  # убедитесь, что колонка называется именно так
        image = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
        if self.transform:
            image = self.transform(image)

        labels = self.label_map[uid]        # теперь это dict, а не Compose
        report = row['report']

        return {'img_f': image, 'report': report, 'labels': labels}


In [36]:
tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

In [37]:
train_ds = IUDataset(train_df, '/kaggle/input/chest-xrays-indiana-university/images/images_normalized', transform=tfm, label_map=label_map)
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2)

test_ds = IUDataset(test_df, '/kaggle/input/chest-xrays-indiana-university/images/images_normalized', transform=tfm, label_map=label_map)
test_dl = DataLoader(test_ds, batch_size=16, shuffle=True, num_workers=2)

## Определение DenseTagger (CheXNet‑backbone)

In [39]:
class DenseTagger(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.backbone = torchvision.models.densenet121(pretrained=True)
        # Заменяем классификатор на multi-label
        in_features = self.backbone.classifier.in_features
        self.backbone.classifier = nn.Linear(in_features, num_labels)

    def forward(self, x):
        # x: [batch, 3, H, W]
        logits = self.backbone(x)
        return logits  # BCEWithLogitsLoss применит sigmoid внутри

In [40]:
# Гиперпараметры
num_labels = 15  # Число патологий в IU X-Ray (настройте по CSV)
batch_size = 16
lr = 1e-4
num_epochs = 5

In [41]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [42]:
# Инициализация модели, loss и оптимизатора
model = DenseTagger(num_labels=num_labels).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)



In [43]:
def train_epoch(model, loader, criterion, optimizer):
    """
    Обучение модели
    """
    model.train()
    total_loss = 0.0
    for batch in loader:
        imgs = batch['img_f'].to(device)  # используем frontal view
        # Если нужны оба вида: concat или два потока
        targets = batch['labels'].to(device)  # tensor [B, num_labels]

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
    return total_loss / len(loader.dataset)

In [44]:
def validate_epoch(model, loader, criterion):
    """
    Валидация модели
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            imgs = batch['img_f'].to(device)
            targets = batch['labels'].to(device)
            logits = model(imgs)
            loss = criterion(logits, targets)
            total_loss += loss.item() * imgs.size(0)
    return total_loss / len(loader.dataset)

Обучение

In [45]:
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_dl, criterion, optimizer)
    val_loss = validate_epoch(model, test_dl, criterion)
    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

# Сохранение чекпойнта
torch.save(model.state_dict(), 'dense_tagger.pt')
print("Training complete. Model saved to dense_tagger.pt")

Epoch 1: Train Loss = 0.2005, Val Loss = 0.1084
Epoch 2: Train Loss = 0.0888, Val Loss = 0.0990
Epoch 3: Train Loss = 0.0697, Val Loss = 0.0981
Epoch 4: Train Loss = 0.0496, Val Loss = 0.1141
Epoch 5: Train Loss = 0.0335, Val Loss = 0.1181
Training complete. Model saved to dense_tagger.pt


In [46]:
model.eval()
with torch.no_grad():
    batch = next(iter(test_dl))
    imgs = batch['img_f'].to(device)
    logits = model(imgs)               # [B,14]
    probs = torch.sigmoid(logits)      # вероятности
    preds = (probs > 0.5).int()        # бинарные метки

print("Probs:\n", probs.cpu().numpy())
print("Preds:\n", preds.cpu().numpy())
print("GT:\n", batch['labels'].cpu().numpy())


Probs:
 [[5.2377232e-03 7.3896036e-02 3.9529526e-03 6.7159072e-03 3.1068497e-03
  7.7399062e-03 1.0152541e-03 2.4027477e-03 1.0218463e-03 6.6921189e-03
  1.1620955e-03 1.6949276e-03 1.2949926e-03 1.1127987e-03 4.1301730e-03]
 [4.0373523e-03 1.9283025e-01 3.7208796e-03 1.4749845e-02 1.5060047e-02
  5.4868641e-03 1.3661016e-03 2.5402396e-03 8.1273285e-04 1.6706083e-02
  1.2976187e-03 1.3232386e-03 1.3801036e-03 1.1620357e-03 2.6445766e-03]
 [7.3696452e-01 2.5527971e-02 7.0201731e-03 4.7813058e-03 2.4575172e-02
  7.6761944e-03 7.6276725e-03 1.5060474e-03 6.1185006e-04 1.5688503e-02
  1.0779948e-03 5.1496155e-04 1.7433182e-03 9.5444382e-04 6.3697244e-03]
 [1.9949256e-02 3.1900895e-03 2.3964492e-03 3.7114909e-03 3.4246254e-03
  4.9994681e-03 3.0515969e-03 3.4977414e-03 1.3011824e-03 3.9115339e-03
  2.5473454e-03 1.4386236e-03 1.9051324e-03 2.0626471e-03 9.0177590e-01]
 [5.5655502e-02 8.4581263e-03 5.0148577e-03 4.5600948e-03 6.1346581e-03
  8.0476739e-03 6.4141299e-03 3.2425066e-03 1.965424

In [47]:
model.eval()
with torch.no_grad():
    batch = next(iter(test_dl))
    imgs = batch['img_f'].to(device)
    probs = torch.sigmoid(model(imgs)).cpu().numpy()
    preds = (probs > 0.5).astype(int)

for i in range(len(imgs)):
    print(f"=== Sample {i} ===")
    print("Probabilities:", {common_labels[j]: round(float(probs[i,j]), 3) for j in range(len(common_labels))})
    print("Predicted:", [common_labels[j] for j in range(len(common_labels)) if preds[i,j] == 1])
    print("Ground‑Truth:", [common_labels[j] for j in range(len(common_labels)) if batch['labels'][i,j] == 1])
    print()

=== Sample 0 ===
Probabilities: {'atelectasis': 0.003, 'cardiomegaly': 0.005, 'consolidation': 0.001, 'edema': 0.003, 'pleural effusion': 0.001, 'pneumonia': 0.003, 'pneumothorax': 0.003, 'support devices': 0.004, 'enlarged cardiomediastinum': 0.002, 'fracture': 0.007, 'lung lesion': 0.001, 'lung opacity': 0.001, 'pleural other': 0.002, 'pleural thickening': 0.002, 'normal': 0.105}
Predicted: []
Ground‑Truth: ['normal']

=== Sample 1 ===
Probabilities: {'atelectasis': 0.002, 'cardiomegaly': 0.002, 'consolidation': 0.002, 'edema': 0.003, 'pleural effusion': 0.003, 'pneumonia': 0.003, 'pneumothorax': 0.003, 'support devices': 0.002, 'enlarged cardiomediastinum': 0.001, 'fracture': 0.002, 'lung lesion': 0.002, 'lung opacity': 0.001, 'pleural other': 0.001, 'pleural thickening': 0.002, 'normal': 0.799}
Predicted: ['normal']
Ground‑Truth: ['normal']

=== Sample 2 ===
Probabilities: {'atelectasis': 0.001, 'cardiomegaly': 0.002, 'consolidation': 0.002, 'edema': 0.004, 'pleural effusion': 0.00

In [48]:
all_preds = []
all_targets = []

model.eval()
with torch.no_grad():
    for batch in test_dl:
        probs = torch.sigmoid(model(batch['img_f'].to(device))).cpu().numpy()
        preds = (probs > 0.5).astype(int)
        all_preds.extend(preds)
        all_targets.extend(batch['labels'].cpu().numpy())

print(classification_report(all_targets, all_preds, target_names=common_labels))


                            precision    recall  f1-score   support

               atelectasis       0.16      0.05      0.08        55
              cardiomegaly       0.51      0.45      0.48        55
             consolidation       0.00      0.00      0.00         6
                     edema       0.00      0.00      0.00         6
          pleural effusion       0.62      0.33      0.43        24
                 pneumonia       0.00      0.00      0.00         8
              pneumothorax       0.00      0.00      0.00         3
           support devices       0.00      0.00      0.00         0
enlarged cardiomediastinum       0.00      0.00      0.00         0
                  fracture       0.00      0.00      0.00        20
               lung lesion       0.00      0.00      0.00         0
              lung opacity       0.00      0.00      0.00         0
             pleural other       0.00      0.00      0.00         0
        pleural thickening       0.00      0.00

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
