<a href="https://colab.research.google.com/github/ML4SCI/DeepLearnHackathon/blob/main/ExoplanetSearchChallenge/Exoplanet_Search_Challenge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Exoplanet Search Challenge

<div style="text-align: center;">
  <img src="https://upload.wikimedia.org/wikipedia/commons/9/9d/HL_Tau_protoplanetary_disk.jpg" alt="HL Tau" width="500" height="500">
</div>

Credit: ALMA (ESO/NAOJ/NRAO)


## Detecting Exoplanets in Protoplanetary Disks

### Description
Protoplanetary disks are the sites of planet formation. They provide laboratories against which theories of planet formation can be tested. State-of-the-art telescopes can observe these systems in unprecedented detail, and the observations contain a wealth of information that can be used to advance theory. However, extracting this information is difficult because the observations are noisy and few disks are well understood. Recently, the interplay of advanced simulations and machine learning has been successful in analyzing these disks and identifying exoplanets [[1](https://ui.adsabs.harvard.edu/abs/2021ApJ...920....3A/abstract)] [[2](https://ui.adsabs.harvard.edu/abs/2022MNRAS.510.4473Z/abstract)] [[3](https://ui.adsabs.harvard.edu/abs/2022ApJ...941..192T/abstract)] [[4](https://ui.adsabs.harvard.edu/abs/2023ApJ...947...60T/abstract)]. This promising avenue of research is the basis for this Hackathon challenge.

### Task
The task is to train a model that is capable of identifying if a synthetic observation contains a planet. This is a binary classification problem: planet or no planet.

### Datasets
The data were generated for [Terry et al. (2022)](https://ui.adsabs.harvard.edu/abs/2022ApJ...941..192T/abstract). They consist of .fits files that represent synthetic continuum observations of protoplanetary disks at 1250 microns. Each simulation, for which there may be several snapshots, consists of a disk with between 0 and 4 planets. The data include [the full training dataset](https://drive.google.com/drive/folders/1BV8FksW_EZnLTWUeHwJ_fEctFgjVbhMp?usp=drive_link), [a subset of the training data](https://drive.google.com/file/d/1I0JS1Qd896BGgsPcga3umQm-RuJB37UA/view?usp=drive_link), and [the training labels](https://drive.google.com/file/d/1gtBi4ILvCe8nTF09p_E9WWMplTQGC2Wr/view?usp=drive_link). The labels correspond to the simulation number, e.g., planet0_xxxx.fits corresponds to run 0. Each .fits file comes with 4 channels, but only the first one is relevant. This example uses only the small training subset because it is meant for speed and clarity rather than performance. Final training should be done on the entire training set.
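
As a minimal sketch of the naming convention described above, the run number can be extracted from a file name with a regular expression. The helper name `run_number_from_filename` and the example file names are hypothetical; only the `planet<run>_<snapshot>.fits` pattern is taken from the description.

```python
import re

# Pattern assumed from the naming convention "planet<run>_<snapshot>.fits"
_RUN_PATTERN = re.compile(r"planet(\d+)_.*\.fits$")


def run_number_from_filename(filename: str) -> int:
    """Return the simulation (run) number encoded in a .fits file name."""
    match = _RUN_PATTERN.search(filename)
    if match is None:
        raise ValueError(f"Unexpected file name: {filename!r}")
    return int(match.group(1))


# Hypothetical file names, used only for illustration
example_names = ["planet0_00024.fits", "planet17_00150.fits"]
example_runs = [run_number_from_filename(name) for name in example_names]
print(example_runs)  # [0, 17]
```

Raising an error on an unexpected name makes stray files in the data folder visible early instead of silently mislabeling them.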

### Evaluation Metrics
* AUC on a withheld test set that will not be provided
* Performance on real observations
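
To make the AUC metric concrete, here is a small pure-Python sketch. It uses the interpretation of AUC as the probability that a randomly chosen positive example (planet) receives a higher score than a randomly chosen negative one (no planet). The function name `pairwise_auc` and the toy labels and scores are made up for illustration; this is not the scoring code used by the organizers.

```python
def pairwise_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a correct pair.
    """
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative example")

    correct = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(positives) * len(negatives))


# Toy example: 3 of the 4 (positive, negative) pairs are ordered correctly
toy_labels = [0, 0, 1, 1]
toy_scores = [0.1, 0.4, 0.35, 0.8]
toy_auc = pairwise_auc(toy_labels, toy_scores)
print(toy_auc)  # 0.75
```

The double loop is quadratic in the number of samples, so it is only meant to show the definition. For real evaluation, `sklearn.metrics.roc_auc_score` (used later in this notebook) is the practical choice.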

## Deliverables
* You are required to submit a Google Colab Jupyter Notebook clearly showing your implementation along with the above-mentioned evaluation metrics for test data. This test data is part of the provided data, but it should not be used as training or validation data. It is not the same data that we will test on.
* A PDF of your final Jupyter notebook
* You must also submit the final trained model, including the model architecture and the trained weights (for example, an HDF5 file, .pb file, or .pt file), in a form that can easily be run on our withheld data.


# Imports

In [None]:
# used for downloading data
!pip install gdown

In [None]:
# used for model
!pip install pytorch_lightning

In [None]:
import os

from astropy.io import fits

import gdown

from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt

import numpy as np

import pandas as pd

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, progress
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.model_selection import train_test_split

import torch
from torch import nn, Tensor
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Dataset, RandomSampler, random_split

import torchmetrics

import torchvision
import torchvision.transforms as T

In [None]:
torch.manual_seed(123)
np.random.seed(123)

# Load data

The method below downloads the data from Google Drive. This is very slow for the entire dataset, so it is recommended that you add the data to your personal Google Drive and mount it using code similar to the following (or work locally).

```python
from google.colab import drive
drive.mount("/content/drive")
data_dir = "/content/drive/My Drive/Full_Train_Data/"
data_names = os.listdir(data_dir)
label_name = "/content/drive/My Drive/train_info.csv"
label_df = pd.read_csv(label_name, usecols=range(1, 11))
```

In [None]:
### Sample data subset (NOT full dataset)
### The dataset used in this example is a very small subset for the sake of speed
### Training on only this data would severely overfit the models
### For deployment, the entire training data folder should be used
# data_id = "15AMGfgEu2ltGZN3rVtMV97mSbF2USrs6" ## for full set
data_id = "1I0JS1Qd896BGgsPcga3umQm-RuJB37UA"
gdown.download(f"https://drive.google.com/uc?id={data_id}", "data_names.zip", quiet=False)

In [None]:
# unzip the downloaded data folder
!unzip -q data_names.zip -d data_names

In [None]:
# load fits files
data_dir = "data_names/Sample_Data/" ### Not full dataset
# data_dir = "data_names/Full_Train_Data/" ### Full dataset
data_names = os.listdir(data_dir)
# make sure there aren't any weird files in the folder
data_names = np.array([x for x in data_names if x.endswith(".fits")])

In [None]:
# get run information
run_nums = np.array([int(x.split("planet")[1].split("_")[0]) for x in data_names])

In [None]:
# sort by run number
order = np.argsort(run_nums)
run_nums = run_nums[order]
data_names = data_names[order]

In [None]:
# Download labels
label_id = "1gtBi4ILvCe8nTF09p_E9WWMplTQGC2Wr"
gdown.download(f"https://drive.google.com/uc?id={label_id}", "label_name.csv", quiet=False)

In [None]:
# Load labels
label_name = "label_name.csv"
label_df = pd.read_csv(label_name, usecols=range(1, 11))
label_df.head()
# (run, number of planets, mass of planet 1, semimajor axis of planet 1, ....)

In [None]:
runs = label_df.run.to_numpy()
Ns = label_df.n.to_numpy()

In [None]:
# label whether it's a planet or not a planet
labels = {}
nums = {}
for (name, run) in zip(data_names, run_nums):
    label = Ns[np.where(runs == run)][0]
    nums[name] = int(label)
    labels[name] = int(label > 0)

In [None]:
# Get actual data
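data = {}
for name in data_names:
    # There are 4 channels, but we only care about the first
    with fits.open(f"{data_dir}{name}") as hdul:
        # copy into a native float32 array so the file can be closed safely
        image = hdul[0].data.squeeze()[0].astype(np.float32)
    # min-max normalize to [0, 1]
    image -= np.min(image)
    image /= np.max(image)
    data[name] = image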
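
The normalization above divides by the maximum after subtracting the minimum, which would produce NaNs if an image were ever constant. A minimal sketch of a guarded version follows. The helper name `min_max_normalize`, the `eps` threshold, and the toy arrays are assumptions made for illustration and are not part of the original pipeline.

```python
import numpy as np


def min_max_normalize(image: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Rescale an image to [0, 1]; a constant image maps to all zeros."""
    image = np.asarray(image, dtype=np.float32)
    shifted = image - image.min()
    peak = shifted.max()
    if peak < eps:
        # Nothing to rescale: avoid dividing by zero
        return np.zeros_like(shifted)
    return shifted / peak


# Toy 2x2 "image" with values between 2 and 10
toy_image = np.array([[2.0, 4.0], [6.0, 10.0]])
normalized = min_max_normalize(toy_image)

# A constant image, which would otherwise produce NaNs
flat = min_max_normalize(np.full((2, 2), 3.0))
```

Whether such a guard is needed depends on the data. It costs little and avoids NaNs propagating into the loss.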

In [None]:
# show some data
# some of the data has zero values due to the orientation
fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(14, 14))

img_index = np.random.randint(0, len(data_names), size=4)

# draw one randomly chosen image per panel
for ax, idx in zip(axs.flat, img_index):
    name = data_names[idx]
    ax.imshow(data[name],
              origin="lower",
              cmap="magma",
              norm=LogNorm(vmin=1e-6, vmax=1),
              )
    ax.set_title(f"{name} ({nums[name]} planets)")

plt.show()

In [None]:
xy_dim = data[data_names[0]].shape[0]

In [None]:
# Initialize the data arrays
X = np.empty((len(data_names), xy_dim, xy_dim))
y = np.empty((len(data_names), 1))

In [None]:
# load images and labels
for i, name in enumerate(data_names):
    X[i, :, :] = data[name]
    y[i, 0] = labels[name]
# add a channel axis
X = X[:, np.newaxis, :, :]

In [None]:
X = X.astype(np.float32)
y = y.astype(np.float32)

In [None]:
# split into train/test/val
##### Report the AUC of the created test data
test_split = 0.2
val_split = 0.2

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_split)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=val_split)
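
Because each simulation may contribute several snapshots, a purely random split over images can place snapshots of the same run in both the training and test sets, which may make the test AUC optimistic. The sketch below splits by run number instead. The helper name `split_runs` and the toy run numbers are hypothetical, and this is one possible approach rather than the method used to build the withheld test set.

```python
import random


def split_runs(run_numbers, test_fraction=0.2, seed=123):
    """Split sample indices so that all snapshots of a run stay together."""
    unique_runs = sorted(set(run_numbers))
    rng = random.Random(seed)
    rng.shuffle(unique_runs)

    # Hold out at least one run for testing
    n_test_runs = max(1, round(test_fraction * len(unique_runs)))
    test_runs = set(unique_runs[:n_test_runs])

    train_idx = [i for i, run in enumerate(run_numbers) if run not in test_runs]
    test_idx = [i for i, run in enumerate(run_numbers) if run in test_runs]
    return train_idx, test_idx


# Toy run numbers: five runs with two snapshots each
toy_runs = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
train_idx, test_idx = split_runs(toy_runs)

train_run_set = {toy_runs[i] for i in train_idx}
test_run_set = {toy_runs[i] for i in test_idx}
print(train_run_set & test_run_set)  # set()
```

In this notebook the same idea could be applied by passing `run_nums` as the group array and indexing `X` and `y` with the returned indices. scikit-learn's `GroupShuffleSplit` offers comparable functionality.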

# Make datasets/loaders

In [None]:
class DiskDataset(Dataset):

    """Dataset wrapping the disk images and their binary labels"""

    def __init__(
        self,
        X: np.ndarray,
        y: np.ndarray,
        transform: list = None,
        accelerator_name: str = "mps",
    ) -> None:

        self.X = X
        self.y = y
        self.transform = transform

        if accelerator_name == "mps":
            self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        elif accelerator_name == "cuda:0":
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            # fall back to the CPU for any other accelerator name
            self.device = torch.device("cpu")

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        x_, y_ = self.X[idx], self.y[idx]

        x_, y_ = torch.from_numpy(x_),\
                torch.from_numpy(y_)

        if self.transform:
            x_ = self.transform(x_)
        return x_.to(self.device), y_.to(self.device)

In [None]:
# need to resize for EffnetV2
input_size = 224
transform = T.Compose([
                        T.Resize((input_size, input_size), antialias=True),
                        T.Normalize(mean=[0.5], std=[0.5]),
])

In [None]:
##### Now we actually make the dataset and dataloader in PyTorch fashion
train_data = DiskDataset(X_train, y_train, transform=transform)
val_data = DiskDataset(X_val, y_val, transform=transform)
test_data = DiskDataset(X_test, y_test, transform=transform)

In [None]:
# this is artificially small due to the tiny amount of data in this example
batch_size = 16

# make the loader
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_data, batch_size=batch_size)

# Make model

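This model is a simple wrapper around torchvision's EfficientNetV2 (the small `efficientnet_v2_s` variant), adapted to take single-channel images and to output a single logit.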

In [None]:
class CustomEfficientNetV2(pl.LightningModule):
    def __init__(self,
                 num_channels: int = 1,
                 num_outputs: int = 1,
                 lr: float = 5e-4,
                 xy_dim: int = 224,
                ):
        super().__init__()
        self.save_hyperparameters()

        # Load EfficientNetV2 model
        self.model = torchvision.models.efficientnet_v2_s()

        # Modify the first convolutional layer if input channels are different from 3
        if num_channels != 3:
            self.model.features[0][0] = nn.Conv2d(num_channels,
                                                  self.model.features[0][0].out_channels,
                                                  kernel_size=self.model.features[0][0].kernel_size,
                                                  stride=self.model.features[0][0].stride,
                                                  padding=self.model.features[0][0].padding,
                                                  bias=False,
                                                 )

        # Modify the final fully connected layer
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features, num_outputs)

        self.criterion = nn.BCEWithLogitsLoss()

        # Initialize containers to store outputs
        self.validation_outputs = []
        self.test_outputs = []

        self.example_input_array = torch.randn((1, num_channels, xy_dim, xy_dim)).float()

    def forward(self, x):
        return self.model(x)

    def _process_batch(self, batch, when: str = "train"):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log(f"{when}_loss", loss)
        if when != "train":
            return {f"{when}_loss": loss, "y_hat": y_hat, "y": y}
        return loss

    def training_step(self, batch, batch_idx):
        return self._process_batch(batch, when="train")

    def validation_step(self, batch, batch_idx):
        outputs = self._process_batch(batch, when="val")
        self.validation_outputs.append(outputs)
        return outputs

    def test_step(self, batch, batch_idx):
        outputs = self._process_batch(batch, when="test")
        self.test_outputs.append(outputs)
        return outputs

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(),
                                lr=self.hparams.lr,
                               )

    def _roc_epoch_end(self, outputs, when: str = "val"):
        """Logs AUC during validation/testing"""
        y_hat = torch.cat([x["y_hat"] for x in outputs]).detach().cpu().numpy()
        y = torch.cat([x["y"] for x in outputs]).detach().cpu().numpy()
        auc = self.calculate_auc(y_hat, y)
        self.log(f"{when}_auc", auc)

    def on_validation_epoch_end(self,):
        self._roc_epoch_end(self.validation_outputs, when="val")
        self.validation_outputs.clear()

    def on_test_epoch_end(self,):
        self._roc_epoch_end(self.test_outputs, when="test")
        self.test_outputs.clear()

    def calculate_auc(self, y_hat, y):
        # Apply sigmoid to predictions if using BCEWithLogitsLoss
        y_hat = torch.sigmoid(torch.tensor(y_hat).float()).numpy()
        auc = roc_auc_score(y, y_hat)
        # np.float32(...) works whether sklearn returns a Python float or a NumPy scalar
        return np.float32(auc)


In [None]:
model = CustomEfficientNetV2()
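
The model above outputs a raw logit and is trained with `nn.BCEWithLogitsLoss`, which is why a sigmoid is applied only when probabilities are needed. As a rough illustration of what such a loss computes for a single example, the sketch below evaluates binary cross-entropy directly from a logit using a numerically stable form. The function names are hypothetical, and this is not necessarily how PyTorch implements the loss internally.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy computed directly from a logit.

    Uses the stable form max(x, 0) - x * y + log(1 + exp(-|x|)),
    which avoids overflow for large |x|.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def sigmoid(logit: float) -> float:
    """Map a logit to a probability (fine for moderate logit values)."""
    return 1.0 / (1.0 + math.exp(-logit))


# Compare with the textbook formula -[y log p + (1 - y) log(1 - p)]
logit, target = 2.0, 1.0
p = sigmoid(logit)
naive = -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
stable = bce_with_logits(logit, target)
print(abs(naive - stable) < 1e-12)  # True
```

The stable form stays finite even for an extreme logit such as 1000, where applying a sigmoid first and then taking a logarithm would fail.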

# Train

Training on the full dataset will take a long time without a GPU.

In [None]:
accelerator_name = "cuda:0"

if accelerator_name == "mps":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
elif accelerator_name == "cuda:0":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
#### necessary for newer PTL versions
devices = 1
# use the GPU only when the device selected above is not the CPU
accelerator = "gpu" if device.type != "cpu" else "cpu"

In [None]:
#### This is artificially small for the purposes of speed
num_epochs = 5
# make the trainer
trainer = pl.Trainer(
    devices=devices,
    accelerator=accelerator,
    max_epochs=num_epochs,
    log_every_n_steps=1,
    callbacks=[
        LearningRateMonitor("epoch"),
        progress.TQDMProgressBar(refresh_rate=1),
        EarlyStopping(
            monitor="val_auc",
            min_delta=0,
            patience=20,
            verbose=False,
            mode="max",  # AUC should increase, so stop when it no longer improves
        ),
    ],
)
trainer.logger._log_graph = True
trainer.logger._default_hp_metric = None
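
The early-stopping callback watches a metric and halts training once it has failed to improve for `patience` consecutive checks. For a metric such as AUC, where higher is better, the direction of "improvement" matters. The sketch below is a simplified illustration of that logic, not Lightning's actual implementation. The function name `stopping_epoch` and the AUC values are made up.

```python
def stopping_epoch(metric_history, patience, mode="max", min_delta=0.0):
    """Return the epoch index at which training would stop, or None."""
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")

    best = None
    epochs_without_improvement = 0
    for epoch, value in enumerate(metric_history):
        if best is None:
            improved = True
        elif mode == "max":
            improved = value > best + min_delta
        else:
            improved = value < best - min_delta

        if improved:
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Made-up validation AUC values that improve steadily
toy_val_auc = [0.60, 0.65, 0.70, 0.74, 0.78]
print(stopping_epoch(toy_val_auc, patience=2, mode="max"))  # None
print(stopping_epoch(toy_val_auc, patience=2, mode="min"))  # 2
```

With a steadily improving AUC, the "max" setting never stops early, while the "min" setting treats every gain as a failure to improve and stops after `patience` epochs. A `patience` larger than the number of epochs means early stopping never triggers at all.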

In [None]:
model = model.to(device)

# fit the model
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Test

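This is the AUC you should report. It is computed on provided data that was not used for training or validation, and it is separate from the withheld test set used for the final evaluation.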

In [None]:
model.eval()
print("Testing model")

In [None]:
# Make a dataloader for the test data
test_batch_size = 16
test_loader = DataLoader(test_data, batch_size=test_batch_size)

In [None]:
# do inference on all test batches
results = []

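for X_batch, _ in test_loader:  # each batch is an (images, labels) pair
    # make sure the images live on the same device as the model
    X_batch = X_batch.to(model.device)
    with torch.no_grad():
        outputs = torch.sigmoid(model(X_batch))
    # move to the CPU and flatten so that single-item batches concatenate cleanly
    batch_results = outputs.detach().cpu().numpy().reshape(-1)
    results.append(batch_results)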

y_pred = np.concatenate(results, axis=0)

In [None]:
# get ROC curve/AUC
fpr, tpr, _ = roc_curve(y_test, y_pred)
auc = roc_auc_score(y_test, y_pred)
# threshold the predicted probabilities at 0.5 to get class predictions
accuracy = np.mean((y_pred >= 0.5) == y_test.ravel())

print(f"Accuracy of {accuracy:.2f}. AUC of {auc:.3f}")

In [None]:
# Plot ROC curve
plt.figure(figsize=(10., 7.5))

plt.plot(fpr, tpr, lw=3, c="steelblue")
plt.plot(np.linspace(0, 1, 100), np.linspace(0, 1, 100),
         c="gray", ls="--", alpha=0.5, lw=3,
         )

plt.xlabel("FPR", fontsize=14)
plt.ylabel("TPR", fontsize=14)

plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

plt.show()