# Case Study: Contrail Segmentation 
## On the Importance of Understanding Your Data and Task
*Rolls off the tongue, doesn't it?* - The Authors of This Notebook

This notebook presents a set of exercises in the form of a case study on segmenting contrails from satellite images (also referred to as remote sensing data).

*Prerequisites*:
- Basic machine learning knowledge 
- Basic familiarity with the paper "Few-Shot Contrail Segmentation in Remote Sensing Imagery With Loss Function in Hough Space" by Junzi et al. (https://ieeexplore.ieee.org/document/10820969). We do not expect you to have read the whole paper, but rather to have an idea of what is being done.

*By the end of this notebook, students should:*
- Gain exposure to a new application of artificial intelligence: extracting semantics from satellite images
- Understand when they can take the specifics of their data and task into account
- Gain exposure to methods that take advantage of the structure of their task
- Understand how to evaluate machine learning models against the goal they are meant to serve
- Go beyond the technical understanding of the problem and consider where this application will be used, whether it will even be helpful, etc.

## Introduction: Contrail Segmentation
The paper "Few-Shot Contrail Segmentation in Remote Sensing Imagery With Loss Function in Hough Space" by Junzi et al. focuses on building an automatic segmentation procedure for contrails that works with only a few data samples by taking advantage of what we know about the problem. This section introduces the problem, from the very basics of what a contrail is to a mathematical formulation of the task.

<p align="center">
  <img src="images/what-is-a-contrail.png" width="800"/>
</p>

*What is a contrail?* 

Contrails or vapour trails are line-shaped clouds produced by aircraft engine exhaust or changes in air pressure, typically at aircraft cruising altitudes several kilometres/miles above the Earth's surface. (Definition - https://en.wikipedia.org/wiki/Contrail)

*Why does it matter?* 

The Earth absorbs solar radiation and re-emits part of it as heat (infrared radiation) that would otherwise escape to space. The ice crystals in contrails can form a dense enough "shield" to trap part of this outgoing radiation. This increases the amount of heat retained around the Earth and thus contributes to warming. Contrails also have a potential cooling effect, because they reflect some incoming sunlight back to space, though this effect is currently estimated to be smaller.


**Satellite Image** 

Satellite images are a unique form of data. As the name implies, they are taken by a satellite, which gives them several characteristics:
- **top-down**: the image is always taken looking down, roughly perpendicular to the ground
- **distance from the Earth**: both the distance to the Earth and the resolution of the image differ from satellite to satellite and from image to image.
    - For this reason, the resolution is measured in terms of actual ground distance (e.g. a resolution of 30 cm means that each pixel covers 30 cm × 30 cm on the ground).
    - Common resolutions are 30 cm, 1 m, and 2 m. A smaller ground distance per pixel means higher image quality, but such imagery is also more expensive and harder to obtain.
- **multispectral data**: Satellite sensors are more advanced than traditional cameras and can capture more than just the visible color spectrum. Each captured range is called a band (a channel, in image terms). Some examples are listed below:
    - RGB (Red-Green-Blue): The channels that traditional cameras capture, which are blended together to create the final image
    - Infrared/Near-Infrared: Channels in which some surfaces respond differently than in visible light; for example, vegetation reflects near-infrared light strongly. They can reveal information not visible to the naked eye.
    - Point Clouds (LiDAR): Measurements of the distance to the Earth's surface, obtained by timing reflected laser pulses. Used to map things like elevation.
    - And many others, depending on the sensor used ... We will introduce them in this notebook if and as necessary.
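
To make the ideas of resolution and bands above more concrete, here is a minimal sketch using NumPy. The image is random data, and both the band order in `BAND_NAMES` and the helper `ground_footprint_km` are hypothetical, chosen only for illustration; real sensors document their own band layouts.

```python
import numpy as np

# Hypothetical band layout, for illustration only.
BAND_NAMES = ["red", "green", "blue", "near_infrared"]

def ground_footprint_km(width_px, height_px, resolution_m):
    """Return the (width, height) in km that an image covers on the ground."""
    return (width_px * resolution_m / 1000, height_px * resolution_m / 1000)

# A fake 4-band image stored as height x width x channels.
rng = np.random.default_rng(0)
image = rng.random((512, 512, len(BAND_NAMES)))

rgb = image[..., :3]                                  # the three visible bands
nir = image[..., BAND_NAMES.index("near_infrared")]   # a single band

print(image.shape, rgb.shape, nir.shape)

# The same 512 x 512 pixel grid covers very different areas
# depending on the ground distance per pixel.
print(ground_footprint_km(512, 512, resolution_m=0.3))
print(ground_footprint_km(512, 512, resolution_m=2.0))
```

The pixel grid is the same in both calls; only the ground distance per pixel changes, which is why resolution is quoted in metres rather than in pixels.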

<div style="display: flex; justify-content: space-around;">
   <img src="images/remote-sensing-platforms.webp" width="450"/>

</div>

[Image Source: GeeksForGeeks](https://imgs.search.brave.com/ebFlnbhNCKYFr760b0JRrC-BQlpJUeTmlB5SUw-5Nf4/rs:fit:860:0:0:0/g:ce/aHR0cHM6Ly9tZWRp/YS5nZWVrc2ZvcmdZWtzLm9yZy93cC)




Below are some example images of contrails taken from different satellites and in different bands:


<div style="display: flex; justify-content: space-around;">
  <img src="images/1-MeteoSat 11.png" width="350"/>
  <img src="images/2-NASA Terra.jpg" width="350"/>
  <img src="images/3-NOAA Suomi-NPP.jpg" width="350"/>
</div>

[Image Source TBD](TBD)


## Project Setup

### Create an Anaconda Environment and Download Requirements
If you haven't already set up the codebase following the instructions in the README, you can do so here. Otherwise, you can skip running this.

In [None]:
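# Note: `conda activate` does not switch the kernel of a notebook that is already
# running. If activation fails here, run these commands in a terminal instead and
# select the new environment as the notebook kernel.
%conda create -y -n contrail-project python=3.12
%conda activate contrail-project
%conda install -y pip
%pip install -r requirements.txt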

### Get Imports

In [None]:
import os
import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim

def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Force deterministic behavior in cudnn (might slow things down)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

## What is Contrail Segmentation?

What are we trying to accomplish? From the name **contrail segmentation**, and the information beforehand, we know that we are trying to extract information out of the satellite images. To do so, we need to define **segmentation** first.

#### Segmentation

Recall **classification**, abstracted from any machine learning model. Given an input $x$, we want to get an output $\widehat y$ corresponding to the correct class $y$ of the object $x$. 

In remote sensing, our inputs are images. Therefore $x \in \mathbb R^{W \times H \times C}$, where $W$ and $H$ are the width and height (in pixels) of the image and $C$ is the total number of channels in the image (recall the RGB and infrared bands introduced previously). Mathematically, this is a **tensor** (it is not very important for today, but good to know in general).

In **segmentation**, we want to understand our image semantically. One way to do this is to output another image that classifies each pixel according to whether it belongs to the object(s) we are looking for. This can be defined as creating a function $f: \mathbb R^{W \times H \times C} \to \mathbb Z_n^{W \times H}$ with $f(x) = y$, where $n$ is the number of classes in our image, so that every pixel of $y$ holds a class label. An illustration is provided below. Specifically, this is a case called **semantic segmentation**.


<div style="display: flex; justify-content: space-around;">
   <img src="images/semantic_segmentation.jpg" width="750"/>
</div>

[Image Source](https://www.hitechbpo.com/blog/semantic-segmentation-guide.php)
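
The per-pixel classification described above can be sketched in NumPy as follows. This assumes a model that outputs one score per class for every pixel (random numbers stand in for it here), and the helper name `segment` is hypothetical. The label of each pixel is simply the class with the highest score.

```python
import numpy as np

def segment(scores):
    """Turn per-pixel class scores of shape (H, W, n) into a label map (H, W)."""
    return np.argmax(scores, axis=-1)

rng = np.random.default_rng(0)
n_classes = 3
scores = rng.random((4, 6, n_classes))   # stand-in for a model's output

labels = segment(scores)

print(labels.shape)        # one label per pixel
print(np.unique(labels))   # every label lies in {0, ..., n_classes - 1}
```

Compared with classification, the only change is that the prediction is made once per pixel instead of once per image.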


**Contrail segmentation** can then be seen as a special case of **semantic segmentation** with $n=2$ classes: "contrail" and "background". An illustration of $x$ and $y$ is shown below:

<div style="display: flex; justify-content: space-around;">
   <img src="contrail-seg/data/goes/florida/image/florida_2020_03_05_0101.png" width="600"/>
   <img src="contrail-seg/data/goes/florida/mask/florida_2020_03_05_0101.png" width="600"/>
</div>

[Image Source TBD](TBD)
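
To give a feel for what $y$ looks like numerically, the sketch below builds a synthetic binary mask in which a thin line stands in for a contrail (it is not taken from the dataset) and measures how few pixels belong to the "contrail" class. Thin, line-shaped objects typically occupy only a small share of the image, so the two classes are far from balanced.

```python
import numpy as np

# Synthetic 64 x 64 binary mask: 1 = contrail, 0 = background.
mask = np.zeros((64, 64), dtype=np.uint8)
idx = np.arange(64)
mask[idx, idx] = 1            # a thin diagonal line, one pixel wide
mask[idx[:-1], idx[1:]] = 1   # thicken it to two pixels

contrail_fraction = mask.mean()
print(f"Contrail pixels: {mask.sum()} of {mask.size} ({contrail_fraction:.1%})")
```

Even with a two-pixel-wide line crossing the whole image, only a few percent of the pixels are labelled as contrail.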

This is a challenging task. Throughout the study, the authors present ways to compensate for their lack of data (only around 30 images). **This is a common occurrence in the ML industry**. In the following sections, you will work through exercises that examine each of their approaches, try to extend them and apply them to a different case, and reason about whether each is an appropriate application.

## How Do We Know Our Model Is Good? - Metrics

## Can We Train Our Model Better? - Other Loss Functions

## Data Scarcity - Data Augmentation

In [None]:
import torch
from torch import nn, optim

def train(model, dataloader, epochs=3):
    """Train a classifier with Adam and cross-entropy loss."""
    model.train()
    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for images, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

def test(model, dataloader):
    """Return the classification accuracy of the model on the dataloader."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            # The predicted class is the index of the highest score.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

#### Natural Language


#### Images

In [None]:
from models.simpleCNN import SimpleCNN
from torch.utils.data import Subset
import torchvision 
import torchvision.transforms as transforms
import torch

def get_balanced_indices_by_targets(targets, samples_per_class):
    """Return dataset indices with the first `samples_per_class` samples of each class."""
    indices = []
    # MNIST has 10 classes (digits 0-9).
    for class_label in range(10):
        class_indices = (targets == class_label).nonzero(as_tuple=True)[0]
        selected = class_indices[:samples_per_class]
        indices.extend(selected.tolist())
    return indices

transform_no_aug = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])


transform_aug = transforms.Compose([
    transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

subset_size = 100
epochs = 50  # WARNING: if too low, may not work
full_trainset_no_aug = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_no_aug)
full_trainset_aug = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_aug)

samples_per_class = subset_size // 10
balanced_indices = get_balanced_indices_by_targets(full_trainset_no_aug.targets, samples_per_class)

trainset_no_aug = Subset(full_trainset_no_aug, balanced_indices)
trainset_aug = Subset(full_trainset_aug, balanced_indices)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_no_aug)

trainloader_no_aug = torch.utils.data.DataLoader(trainset_no_aug, batch_size=64, shuffle=True)
trainloader_aug = torch.utils.data.DataLoader(trainset_aug, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

model_no_aug = SimpleCNN()
train(model_no_aug, trainloader_no_aug, epochs=epochs)
acc_no_aug = test(model_no_aug, testloader)

model_aug = SimpleCNN()
train(model_aug, trainloader_aug, epochs=epochs)
acc_aug = test(model_aug, testloader)

print(f"Accuracy without augmentation: {acc_no_aug:.4f}")
print(f"Accuracy with augmentation: {acc_aug:.4f}")
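
The MNIST example above augments images for classification, where the label does not change when the image is rotated or shifted. For segmentation, any geometric transform has to be applied identically to the image and its mask, otherwise the labels no longer line up with the pixels. Below is a minimal NumPy sketch under that assumption; `augment_pair` is a hypothetical helper, not part of the paper's code or of torchvision, and it uses only flips and 90-degree rotations for simplicity.

```python
import numpy as np

def augment_pair(image, mask, rng):
    """Apply the same random flip and 90-degree rotation to image and mask."""
    if rng.random() < 0.5:
        # Horizontal flip of both arrays along the width axis.
        image, mask = np.flip(image, axis=1), np.flip(mask, axis=1)
    k = int(rng.integers(0, 4))               # number of 90-degree turns
    image = np.rot90(image, k, axes=(0, 1))   # rotate the spatial axes only
    mask = np.rot90(mask, k, axes=(0, 1))
    return image.copy(), mask.copy()

rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))                  # H x W x C toy image
mask = (image[..., 0] > 0.9).astype(np.uint8)    # toy mask tied to the image

aug_image, aug_mask = augment_pair(image, mask, rng)

# The mask still lines up with the image after augmentation.
assert np.array_equal(aug_mask, (aug_image[..., 0] > 0.9).astype(np.uint8))
print(aug_image.shape, aug_mask.shape)
```

Because the toy mask is derived from the image itself, the final assertion checks that the pair stayed aligned; with real data the same guarantee comes from drawing the random parameters once and reusing them for both arrays.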

#### Audio

## Can We Take Advantage of Other Tasks - Transfer Learning