In [None]:

import os
import mlflow
import mlflow.pytorch
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
from sklearn.metrics import f1_score
from ray import train, tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

In [None]:
class CheXpertDataset(Dataset):
    """Multi-label image dataset indexed by a CSV file.

    The CSV must contain an ``image_path`` column. Every other column is
    treated as a label and cast to ``float32``, in CSV column order.

    Args:
        csv_file: Path or buffer accepted by ``pandas.read_csv``.
        transform: Callable applied to each RGB ``PIL.Image``. When ``None``
            (the default), images are resized to 224x224 and converted with
            ``ToTensor``.
    """

    def __init__(self, csv_file, transform=None):
        df = pd.read_csv(csv_file)
        self.image_paths = df['image_path'].values
        # NOTE(review): raw CheXpert labels contain NaN (not mentioned) and -1
        # (uncertain); nn.BCELoss expects targets in [0, 1]. Confirm the CSV
        # was cleaned upstream before it reaches this class.
        self.labels = df.drop(columns=['image_path']).values.astype('float32')
        # Compare against None explicitly: `transform or default` would
        # replace a caller-supplied transform that happens to be falsy (e.g.
        # an empty nn.Sequential, whose __len__ is 0).
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ])
        self.transform = transform

    def __len__(self):
        """Number of rows (images) in the CSV."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_tensor)`` for row ``idx``."""
        # Context manager closes the file handle deterministically; convert()
        # loads the pixels and returns a new image, so it is safe to use
        # after the source image is closed.
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert("RGB")
        img = self.transform(img)
        label = torch.tensor(self.labels[idx])
        return img, label

In [None]:

def get_model(num_classes=14, pretrained=True):
    """Build a DenseNet-121 whose head outputs per-class probabilities.

    Args:
        num_classes: Number of independent binary labels (default 14).
        pretrained: When True, initialise the backbone with the ImageNet
            ``IMAGENET1K_V1`` checkpoint; when False, use random weights.

    Returns:
        ``nn.Module`` mapping an image batch to ``(batch, num_classes)``
        probabilities in (0, 1), via a final ``nn.Sigmoid``.
    """
    # `weights=` replaces torchvision's deprecated `pretrained=` flag;
    # IMAGENET1K_V1 is the checkpoint that `pretrained=True` used to load.
    weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
    base = models.densenet121(weights=weights)
    num_ftrs = base.classifier.in_features
    base.classifier = nn.Sequential(
        nn.Linear(num_ftrs, num_classes),
        # Sigmoid -> independent probability per label (multi-label task).
        # NOTE(review): raw logits + nn.BCEWithLogitsLoss is numerically more
        # stable than Sigmoid + nn.BCELoss.
        nn.Sigmoid(),
    )
    return base

In [None]:
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over `loader`.

    Returns:
        Tuple ``(avg_loss, preds, targets)``: mean per-batch loss, and the
        concatenated model outputs / labels for the epoch as numpy arrays.
    """
    model.train()
    total_loss = 0.0
    all_preds, all_labels = [], []
    for images, labels in loader:
        # No-op when the loader already yields batches on `device`.
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        all_preds.append(outputs.detach().cpu())
        all_labels.append(labels.cpu())

    avg_loss = total_loss / len(loader)
    preds = torch.cat(all_preds).numpy()
    targets = torch.cat(all_labels).numpy()
    return avg_loss, preds, targets


def train_loop(config):
    """Ray Train per-worker loop: train the model and track it in MLflow.

    Args:
        config: dict with keys
            - "csv_path": CSV read by CheXpertDataset.
            - "epochs", "lr", "batch_size": training hyper-parameters.
            - "mlflow_uri": MLflow tracking server URI.
            - "experiment_name" (optional): MLflow experiment name,
              default "CheXpert_DenseNet121".

    Each epoch's "loss" and "f1_score" (macro F1 at threshold 0.5) are
    logged to MLflow by rank 0 and reported to Ray via ``train.report``.
    Rank 0 also logs the final model to MLflow.

    Raises:
        ValueError: if the CSV contains no rows.
    """
    import contextlib

    device = train.torch.get_device()
    context = train.get_context()
    world_size = context.get_world_size()
    # Only rank 0 talks to MLflow, so scaling out still yields one MLflow run.
    is_rank_zero = context.get_world_rank() == 0

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    train_data = CheXpertDataset(config["csv_path"], transform=transform)
    if len(train_data) == 0:
        raise ValueError(f"No training rows found in {config['csv_path']!r}")
    loader = DataLoader(train_data, batch_size=config["batch_size"], shuffle=True)
    # Adds a DistributedSampler when world_size > 1 and moves batches to
    # the worker's device.
    loader = train.torch.prepare_data_loader(loader)

    # Moves the model to the worker's device; wraps it in DDP when world_size > 1.
    model = train.torch.prepare_model(get_model())
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    criterion = nn.BCELoss()

    if is_rank_zero:
        mlflow.set_tracking_uri(config["mlflow_uri"])
        mlflow.set_experiment(config.get("experiment_name", "CheXpert_DenseNet121"))
        run_ctx = mlflow.start_run()
    else:
        run_ctx = contextlib.nullcontext()

    # `with` ends the MLflow run even if training raises.
    with run_ctx:
        for epoch in range(config["epochs"]):
            if world_size > 1:
                # DistributedSampler needs the epoch to reshuffle differently each pass.
                loader.sampler.set_epoch(epoch)

            avg_loss, preds, targets = _train_one_epoch(
                model, loader, optimizer, criterion, device
            )
            # NOTE(review): with world_size > 1 these are rank-local (shard) metrics.
            # zero_division=0 keeps sklearn's default value but silences the warning.
            f1 = f1_score(targets, preds > 0.5, average='macro', zero_division=0)

            if is_rank_zero:
                mlflow.log_metric("loss", avg_loss, step=epoch)
                mlflow.log_metric("f1_score", f1, step=epoch)
            train.report({"loss": avg_loss, "f1_score": f1})

        if is_rank_zero:
            # Log the underlying module, not a DDP wrapper, when one was applied.
            mlflow.pytorch.log_model(getattr(model, "module", model), "model")

In [None]:
# Ray Train runs workers with a working directory that differs from the
# notebook's (it chdirs to a trial/session dir by default), so hand workers
# an absolute CSV path.
# NOTE(review): image paths inside the CSV must also be absolute, or resolvable
# from the worker's cwd.
config = {
    "csv_path": os.path.abspath("train_labels.csv"),  # Your CSV file path
    "epochs": 5,
    "lr": 1e-4,
    "batch_size": 16,
    # Prefer the standard MLflow env var; the literal is a placeholder to edit.
    "mlflow_uri": os.environ.get("MLFLOW_TRACKING_URI", "http://<your-cpu-ip>:8000"),
}

trainer = TorchTrainer(
    train_loop_per_worker=train_loop,
    train_loop_config=config,
    scaling_config=ScalingConfig(num_workers=1, use_gpu=True),
)

result = trainer.fit()
result
