In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from dataset import CXRDataset
import os

In [2]:

DATA_DIR = "../data/rsna/images"
LABELS_CSV = "../data/rsna/labels.csv"
SAVE_PATH = "model_weights/baseline.pt"

In [3]:
#Image Transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [4]:
# Load dataset
df = pd.read_csv(LABELS_CSV)
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['label'], random_state=42)

train_dataset = CXRDataset(train_df, DATA_DIR, transform)
val_dataset = CXRDataset(val_df, DATA_DIR, transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

In [5]:

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [6]:
import warnings
warnings.filterwarnings("ignore")
# Model setup (DenseNet121)
model = models.densenet121(pretrained=True)
model.classifier = nn.Linear(model.classifier.in_features, 1)
model = model.to(device)

In [7]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [8]:
EPOCHS = 3  
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        imgs, labels = imgs.to(device), labels.to(device).float().unsqueeze(1)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{EPOCHS}]  Train Loss: {avg_loss:.4f}")

Epoch 1/3: 100%|██████████| 1512/1512 [04:48<00:00,  5.24it/s]


Epoch [1/3]  Train Loss: 0.4088


Epoch 2/3: 100%|██████████| 1512/1512 [04:47<00:00,  5.27it/s]


Epoch [2/3]  Train Loss: 0.3646


Epoch 3/3: 100%|██████████| 1512/1512 [04:46<00:00,  5.27it/s]

Epoch [3/3]  Train Loss: 0.3238





In [9]:
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
torch.save(model.state_dict(), SAVE_PATH)
print(f"Model saved to {SAVE_PATH}")

Model saved to model_weights/baseline.pt
