# Setup and load data

## Install dependencies and setup

In [None]:
!pip3 install tensorflow-macos tensorflow-metal opencv-python matplotlib panda

In [None]:
import tensorflow as tf
import os
import pandas as pd
from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
from tensorflow.keras import optimizers

### Check whether a GPU is available

In [None]:
# Report whether TensorFlow can see a GPU (e.g. through the Metal plugin) or
# will fall back to the CPU.
gpus = tf.config.list_physical_devices("GPU")
status_message = ("✅ Using GPU:", gpus) if gpus else ("❌ No GPU found, using CPU",)
print(*status_message)

## Load Labels

In [None]:
# Load CSV
train_df = pd.read_csv("CheXpert-v1.0-small/train.csv")
valid_df = pd.read_csv("CheXpert-v1.0-small/valid.csv")

LABELS = ["Pneumothorax", "Pneumonia", "Edema", "Pleural Effusion", "Consolidation", "Cardiomegaly", "Atelectasis"]

# making dataset binary
train_df[LABELS] = train_df[LABELS].fillna(0).replace(-1, 0)
valid_df[LABELS] = valid_df[LABELS].fillna(0).replace(-1, 0)

# Data Augmentation

In [None]:
def _random_rotate(image, max_angle):
    """Rotate an (H, W, C) image by a random angle in [-max_angle, max_angle] radians.

    Built only from TensorFlow ops, so it works inside a traced
    `tf.data.Dataset.map` function. Pixels rotated in from outside the frame
    are filled with 0.
    """
    angle = tf.random.uniform([], minval=-max_angle, maxval=max_angle)
    height = tf.cast(tf.shape(image)[0], tf.float32)
    width = tf.cast(tf.shape(image)[1], tf.float32)
    cos_a = tf.math.cos(angle)
    sin_a = tf.math.sin(angle)
    # Output->input projective transform for a rotation about the image centre
    # (same construction as TensorFlow Addons' angles_to_projective_transforms).
    x_offset = ((width - 1) - (cos_a * (width - 1) - sin_a * (height - 1))) / 2.0
    y_offset = ((height - 1) - (sin_a * (width - 1) + cos_a * (height - 1))) / 2.0
    transform = tf.reshape(
        tf.stack([cos_a, -sin_a, x_offset, sin_a, cos_a, y_offset, 0.0, 0.0]), [1, 8]
    )
    rotated = tf.raw_ops.ImageProjectiveTransformV3(
        images=image[tf.newaxis],
        transforms=transform,
        output_shape=tf.shape(image)[:2],
        fill_value=0.0,
        interpolation="BILINEAR",
        fill_mode="CONSTANT",
    )[0]
    # The raw op loses the static shape; restore it for downstream layers.
    rotated.set_shape(image.shape)
    return rotated


def augment_image(image, label):
    """Randomly augment one training image.

    This is applied with `tf.data.Dataset.map`, which traces it into a graph,
    so it must use TensorFlow ops only: `.numpy()` and NumPy-based Keras
    helpers fail on the symbolic tensors seen during tracing.

    Args:
        image: float image tensor, expected (H, W, C) with values in [0, 1].
        label: passed through unchanged.

    Returns:
        The augmented image clipped to [0, 1], and the unchanged label.
    """
    # Small random brightness and contrast changes.
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)

    # Random horizontal flip.
    # NOTE(review): chest X-rays are not left-right symmetric (e.g. heart
    # position matters for Cardiomegaly) — confirm flips are wanted.
    image = tf.image.random_flip_left_right(image)

    # Random rotation of up to ±0.1 rad (about ±5.7 degrees).
    image = _random_rotate(image, max_angle=0.1)

    # Brightness/contrast can push values outside [0, 1]; clamp them back.
    image = tf.clip_by_value(image, 0.0, 1.0)

    return image, label

## Building the TensorFlow dataset

In [None]:
IMG_SIZE = 224
BATCH_SIZE = 32


def parse_image(filename, label):
    """Load one image file as a float32 (IMG_SIZE, IMG_SIZE, 3) tensor in [0, 1].

    Args:
        filename: scalar string tensor with the image path.
        label: passed through unchanged.

    Returns:
        The decoded, resized and rescaled image, and the label.
    """
    img = tf.io.read_file(filename)
    # Decode the JPEG as 3 channels so it matches the 3-channel backbone input.
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
    # Scale pixel values from [0, 255] to [0, 1].
    img = tf.cast(img, tf.float32) / 255.0
    return img, label


def df_to_dataset(df, base_dir="CheXpert-v1.0-small", training=False):
    """Build a batched, prefetched `tf.data.Dataset` of (image, label-vector) pairs.

    Args:
        df: DataFrame with a "Path" column and one column per entry in LABELS.
        base_dir: directory prepended to each "Path" value.
        training: if True, shuffle every epoch and apply `augment_image`.

    Returns:
        Dataset yielding batches of BATCH_SIZE images and float32 label rows.
    """
    # NOTE(review): in the public CheXpert-v1.0-small CSVs "Path" already starts
    # with "CheXpert-v1.0-small/"; if so, this join doubles the prefix — confirm
    # against the local directory layout.
    filepaths = df["Path"].apply(lambda p: os.path.join(base_dir, p)).to_numpy()
    labels = df[LABELS].to_numpy().astype("float32")
    ds = tf.data.Dataset.from_tensor_slices((filepaths, labels))
    if training:
        # Shuffle the (path, label) pairs before decoding: the buffer then holds
        # only strings, so it can cover the whole dataset cheaply instead of
        # buffering 1000 decoded float images.
        ds = ds.shuffle(max(len(df), 1), reshuffle_each_iteration=True)
    # Decode images in parallel.
    ds = ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
    if training:
        ds = ds.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train_ds = df_to_dataset(train_df, training=True)
valid_ds = df_to_dataset(valid_df, training=False)

# Building the model

In [None]:
# Ready-made backbone: ResNet50 pretrained on ImageNet already detects edges,
# shapes and textures, so we reuse that knowledge on our dataset.
base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(IMG_SIZE, IMG_SIZE, 3))
# Freeze the backbone at first so training the new head cannot overwrite the
# pretrained features.
base_model.trainable = False
# NOTE(review): Keras documents that ResNet50 expects inputs prepared with
# tf.keras.applications.resnet50.preprocess_input (0-255 range, BGR, ImageNet
# mean subtracted); the input pipeline feeds [0, 1] images — consider aligning.

model = models.Sequential([
    base_model,
    # Average every feature map down to one value -> one feature vector per image.
    layers.GlobalAveragePooling2D(),
    # During training, drop each feature with probability 0.5 so the head
    # cannot rely on just a few of them.
    layers.Dropout(0.5),
    # One independent sigmoid per finding (multi-label), not a softmax over them.
    layers.Dense(len(LABELS), activation="sigmoid"),
])

# Label imbalance.
# Keras' `class_weight` converts each label vector to a single class index with
# argmax. For multi-label targets, every sample would be weighted by its first
# positive label, and all-negative samples by label 0's weight. Instead,
# up-weight the positive term of each label's binary cross-entropy by
# #negatives / #positives.
positive_counts = train_df[LABELS].sum().clip(lower=1)  # clip avoids division by zero
negative_counts = (len(train_df) - positive_counts).clip(lower=1)
pos_weight = tf.constant((negative_counts / positive_counts).to_numpy(dtype="float32"))


def weighted_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy averaged over labels, with label j's positives scaled by pos_weight[j].

    With all weights equal to 1 this is ordinary binary cross-entropy.
    NOTE: a model saved with this loss must be reloaded with
    custom_objects={"weighted_binary_crossentropy": ...} or compile=False.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    epsilon = tf.keras.backend.epsilon()
    # Keep log() finite at predictions of exactly 0 or 1.
    y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
    per_label = -(
        pos_weight * y_true * tf.math.log(y_pred)
        + (1.0 - y_true) * tf.math.log(1.0 - y_pred)
    )
    return tf.reduce_mean(per_label, axis=-1)


# Imbalance is handled inside the loss, so no Keras class weighting is applied
# (the fit calls pass this through as class_weight).
class_weight_dict = None

# Optimizer: how the model updates itself; loss: how wrong it is;
# metrics: progress tracking. AUC is computed per label and averaged, and is
# more informative than accuracy on imbalanced findings.
model.compile(
    optimizer="adam",
    loss=weighted_binary_crossentropy,
    metrics=["accuracy", tf.keras.metrics.AUC(multi_label=True, name="auc")],
)

# Adding callbacks

In [None]:
# Both callbacks watch the validation loss.
monitored_metric = "val_loss"

# Stop after 5 epochs without improvement and roll back to the best weights seen.
early_stop = EarlyStopping(monitor=monitored_metric, patience=5, restore_best_weights=True)

# Write the model to best_model.keras only when the monitored value improves.
checkpoint = ModelCheckpoint("best_model.keras", monitor=monitored_metric, save_best_only=True)

# Train

In [None]:
# Stage 1: train the new head on top of the frozen backbone for a few epochs,
# with early stopping / checkpointing and the weighting from class_weight_dict.
stage_one_callbacks = [early_stop, checkpoint]
history = model.fit(
    train_ds,
    epochs=5,
    validation_data=valid_ds,
    class_weight=class_weight_dict,
    callbacks=stage_one_callbacks,
)

## Unfreeze some layers

In [None]:
# Make the whole backbone trainable, then re-freeze everything except its
# last 50 layers so only the top of ResNet50 is fine-tuned.
base_model.trainable = True
for layer in base_model.layers[:-50]:
    layer.trainable = False

# Keep every BatchNormalization layer frozen too. A BN layer with
# trainable=False runs in inference mode in Keras. Otherwise, the unfrozen ones
# would re-estimate their moving statistics from the new data, which the Keras
# fine-tuning guide warns can wreck the pretrained features.
for layer in base_model.layers:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

## Recompiling with a smaller learning rate

In [None]:
# Recompile so the new trainable flags take effect. The learning rate is far
# below Adam's default so the unfrozen pretrained weights change only gently.
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-5),
    # Reuse the loss from the first compile so both stages optimise (and the
    # val_loss-monitoring callbacks compare) the same objective.
    loss=model.loss,
    metrics=["accuracy", tf.keras.metrics.AUC(multi_label=True, name="auc")],
)

## Train

In [None]:
# Stage 2: fine-tune the partially unfrozen backbone.
# `epochs` is the total epoch count, and `initial_epoch` is the zero-based index
# of the first epoch to run. Stage 1 ran indices 0..len(history.epoch)-1, so
# continue at len(history.epoch). `history.epoch[-1]` would repeat the last
# stage-1 index and raise IndexError if stage 1 recorded no epochs.
history_fine = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=50,
    initial_epoch=len(history.epoch),
    callbacks=[early_stop, checkpoint],
    class_weight=class_weight_dict,
)

# Save the model

In [None]:
# best_model.keras is written by the ModelCheckpoint callback and holds the
# lowest-val_loss checkpoint. Saving the in-memory model to the same path would
# overwrite it with whatever weights the model ends with, so keep the final
# model in its own file.
model.save("final_model.keras")

# Plotting the training history

In [None]:
def _plot_metric(ax, histories, titles, key, short_name, title):
    """Draw the train and validation curves of one metric for every history on `ax`."""
    for hist, name in zip(histories, titles):
        # History.epoch holds the real epoch indices (offset by initial_epoch),
        # so consecutive training stages continue along the x-axis instead of
        # all starting at 0 and overlapping.
        ax.plot(hist.epoch, hist.history[key], label=f"{name} train {short_name}")
        ax.plot(hist.epoch, hist.history[f"val_{key}"], label=f"{name} val {short_name}")
    ax.set_title(title)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(title)
    ax.legend()


def plot_history(histories, titles):
    """Plot loss and accuracy curves side by side for one or more Keras History objects.

    Args:
        histories: History objects returned by `model.fit`, in training order.
        titles: legend prefix for each history, parallel to `histories`.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))
    _plot_metric(loss_ax, histories, titles, "loss", "loss", "Loss")
    _plot_metric(acc_ax, histories, titles, "accuracy", "acc", "Accuracy")
    fig.tight_layout()
    plt.show()


# Call the function to plot
plot_history([history, history_fine], ["Initial", "Fine-tuning"])