# Assignment: Extreme Event Detection and Classification using Convolution and CNN

## Part 1: Data Loading, Visualization, and Labeling

### Step 1: Load and Extract Precipitation Data

We will start by loading the precipitation data from the NetCDF file. Then, we will plot a heatmap for the 180th day to visualize the precipitation on that day.

```python
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

# Open the NetCDF file (this dataset has already been shared with you in the classroom)
file_path = 'path_to_your_file/PERCDR_0.25deg_2001_2010_precipitation_data.nc'
dataset = xr.open_dataset(file_path)

# Pull the precipitation variable out as a NumPy array
# (replace 'precip' with the correct variable name if yours is different)
precipitation = dataset['precip'].values
print("Data shape:", precipitation.shape)  # Shows the dimensions of the loaded array

# Draw a filled-contour heatmap of the 180th day (index 179 because indexing starts at 0)
day_180 = precipitation[179]
fig, ax = plt.subplots(figsize=(8, 6))
filled_contours = ax.contourf(day_180, cmap='Blues')
fig.colorbar(filled_contours, ax=ax, label='Precipitation (mm)')
ax.set_title('Precipitation Heatmap for the 180th Day')
ax.set_xlabel('Longitude Index')
ax.set_ylabel('Latitude Index')
plt.show()


```

### Step 2: Define Convolution Function
Next, we will provide you with a filter pattern and a threshold value. You will write the ```convolve_stride_pad``` function to apply this filter to the precipitation data.

```python
# Filter pattern: a 3x3 kernel with every neighbour weighted +1 and the centre
# weighted -8, so the nine weights sum to zero
extreme_event_filter = np.ones((3, 3), dtype=int)
extreme_event_filter[1, 1] = -8

# Threshold value used when labeling extreme events
extreme_threshold = 75

# Define the custom convolution function with stride and padding (same code as you saw in class)
def convolve_stride_pad(image, filter_pattern, stride=1, padding=0):
    """Slide a 2-D filter over a 2-D image and sum the element-wise products.

    The filter is applied as given (it is not flipped before multiplication).
    Square and rectangular filters are both supported: the window takes its
    height and width from ``filter_pattern.shape``.

    Args:
        image: 2-D NumPy array to filter.
        filter_pattern: 2-D NumPy array of filter weights.
        stride: Step, in grid cells, between consecutive window positions
            along both axes.
        padding: Number of zero-valued cells added on every side of
            ``image`` before filtering; values <= 0 add nothing.

    Returns:
        A 2-D float array with
        ``(rows - filter_rows) // stride + 1`` rows and
        ``(cols - filter_cols) // stride + 1`` columns, where ``rows`` and
        ``cols`` are the image dimensions after padding.
    """
    if padding > 0:
        # mode='constant' fills the border with zeros
        image = np.pad(image, ((padding, padding), (padding, padding)), mode='constant')

    rows, cols = image.shape
    # Read height and width separately so non-square filters get a matching window
    filter_rows, filter_cols = filter_pattern.shape
    output_rows = (rows - filter_rows) // stride + 1
    output_cols = (cols - filter_cols) // stride + 1
    output = np.zeros((output_rows, output_cols))

    for out_i in range(output_rows):
        top = out_i * stride
        for out_j in range(output_cols):
            left = out_j * stride
            patch = image[top:top + filter_rows, left:left + filter_cols]
            output[out_i, out_j] = np.sum(patch * filter_pattern)

    return output

# Run the convolution on the 180th day as a quick check
conv_result = convolve_stride_pad(day_180, extreme_event_filter, stride=1, padding=1)
print("Convolved result shape:", conv_result.shape)

# Cell boundaries sit at half-integer positions; they are used for both ticks and grid lines
n_rows, n_cols = conv_result.shape
x_edges = np.arange(-.5, n_cols, 1)
y_edges = np.arange(-.5, n_rows, 1)

# Show the convolved field as a heatmap
fig, ax = plt.subplots(figsize=(8, 6))
heat = ax.imshow(conv_result, cmap='hot', interpolation='nearest')
fig.colorbar(heat, ax=ax, label='Convolution Result')
ax.set_title('Convolution Result Heatmap for the 180th Day')
ax.set_xlabel('Convolved Longitude Index')
ax.set_ylabel('Convolved Latitude Index')

# Grid lines make the individual cells easier to see
ax.grid(True, which='both', color='grey', linestyle='-', linewidth=0.5)

# Put major ticks (with integer labels) and minor ticks on the cell boundaries
ax.set_xticks(x_edges)
ax.set_xticklabels(np.arange(0, n_cols + 1, 1))
ax.set_yticks(y_edges)
ax.set_yticklabels(np.arange(0, n_rows + 1, 1))
ax.set_xticks(x_edges, minor=True)
ax.set_yticks(y_edges, minor=True)
ax.grid(which='minor', color='grey', linestyle='-', linewidth=0.5)

plt.show()

```

### Step 3: Label the Data and Plot Heatmap
We will now label each day based on the convolution result and the threshold value. Then, we will plot a heatmap showing the number of extreme events.
```python
# Function to label the data
def label_data(data, filter_pattern, threshold):
    """Assign a binary label to every slice along the first axis of ``data``.

    Each slice is filtered with ``convolve_stride_pad`` (stride 1, padding 1).
    A slice is labeled 1 when the largest value of the filtered result
    exceeds ``threshold``, and 0 otherwise.

    Args:
        data: Array whose first axis indexes the days to label.
        filter_pattern: 2-D filter passed to ``convolve_stride_pad``.
        threshold: Value the maximum filter response must exceed.

    Returns:
        NumPy array of 0/1 labels, one per slice.
    """
    def peak_response(field):
        # Largest filter response anywhere in this day's field
        return np.max(convolve_stride_pad(field, filter_pattern, stride=1, padding=1))

    return np.array([
        1 if peak_response(data[day_index]) > threshold else 0
        for day_index in range(data.shape[0])
    ])

# Label every day in the record
labels = label_data(precipitation, extreme_event_filter, extreme_threshold)

# Report how many days were labeled as extreme
print("Number of extreme event days:", np.sum(labels))

# For every grid point, count how often precipitation exceeds the threshold,
# looking only at the days labeled as extreme
extreme_event_counts = np.zeros(precipitation.shape[1:3])
for extreme_day in np.flatnonzero(labels == 1):
    extreme_event_counts += precipitation[extreme_day] > extreme_threshold

# Show the per-grid-point counts as a filled-contour heatmap
fig, ax = plt.subplots(figsize=(10, 8))
count_contours = ax.contourf(extreme_event_counts, cmap='hot', levels=10)
fig.colorbar(count_contours, ax=ax, label='Number of Extreme Events')
ax.set_xlabel('Longitude Index')
ax.set_ylabel('Latitude Index')
ax.set_title('Heatmap of Extreme Events')
plt.show()
```

## Part 2: Training and Testing a CNN Classifier
### Step 4: Data Preparation
We will provide the details for loading and preparing the data for training and testing.

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Define a custom Dataset class for the labeled precipitation data
class PrecipitationDataset(Dataset):
    """Dataset that pairs each precipitation field with its label.

    Indexing returns ``(field, label)`` as float32 tensors, with a leading
    channel axis added to the field.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # One sample per entry along the first axis of the data
        return len(self.data)

    def __getitem__(self, idx):
        # Prepend a channel axis so the 2-D field becomes (1, rows, cols)
        field = self.data[idx][np.newaxis, :, :]
        features = torch.tensor(field, dtype=torch.float32)
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return features, target

# Chronological split: the first 2922 days (first 8 years) are for training,
# the remaining days (last 2 years) are for testing
train_data, test_data = precipitation[:2922], precipitation[2922:]
train_labels, test_labels = labels[:2922], labels[2922:]

# Wrap each split in the custom Dataset
train_dataset = PrecipitationDataset(train_data, train_labels)
test_dataset = PrecipitationDataset(test_data, test_labels)

# Batch the samples; only the training batches are shuffled
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32)
```

### Step 5: Design the CNN Classifier
You will now design the ```CNNClassifier```. Here are the technical details:

- Conv1: Input channels = 1, Output channels = 16, Kernel size = 3, Padding = 1
- Activation: ReLU
- Pooling: MaxPool2d with Kernel size = 2, Stride = 2
- Conv2: Input channels = 16, Output channels = 32, Kernel size = 3, Padding = 1
- FC1: Input features = 32 * 3 * 4, Output features = 64
- FC2: Input features = 64, Output features = 1
- Activation: Sigmoid for the output

```python
# Your CNNClassifier model implementation here (Caution: mismatched tensor dimensions or shapes can raise errors; try using squeeze and unsqueeze if necessary)

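```

### Step 6: Train, Test, and Evaluate the Model
Here is the full code implementation to train, test, and evaluate the model. It expects a `model` object (an instance of your ```CNNClassifier``` from Step 5) to already exist.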

```python
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt

# Define loss function and optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # Batch-specific names keep the full `labels` array created in Step 3
    # from being overwritten by a single batch of labels.
    for batch_inputs, batch_labels in train_loader:
        targets = batch_labels.unsqueeze(1)  # Convert labels to shape [batch_size, 1]

        optimizer.zero_grad()
        outputs = model(batch_inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Report the mean per-batch loss for this epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

# Evaluate the model
model.eval()  # Set the model to evaluation mode
test_labels_list = []
pred_labels_list = []

with torch.no_grad():  # Disable gradient computation for evaluation
    # Batch-specific names keep the full `labels` array created in Step 3
    # from being overwritten by a single batch of labels.
    for batch_inputs, batch_labels in test_loader:
        outputs = model(batch_inputs)
        predicted = (outputs > 0.5).float()  # Threshold the sigmoid output to get binary predictions
        test_labels_list.extend(batch_labels.numpy())  # Store true labels
        pred_labels_list.extend(predicted.numpy().flatten())  # Store predicted labels

# Convert lists to numpy arrays
test_labels_list = np.array(test_labels_list)
pred_labels_list = np.array(pred_labels_list)

# Score the predictions against the true labels
y_true, y_pred = test_labels_list, pred_labels_list
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
conf_matrix = confusion_matrix(y_true, y_pred)

# Report accuracy as a percentage and the other scores as fractions
print(f'Accuracy: {accuracy * 100:.2f}%')
print(f'Precision: {precision:.2f}')
print(f'Recall: {recall:.2f}')
print(f'F1 Score: {f1:.2f}')

# Show the raw confusion matrix
print('Confusion Matrix:')
print(conf_matrix)

# Draw the confusion matrix as a colour-coded grid
class_names = ['Non-Extreme', 'Extreme']
tick_marks = np.arange(2)
plt.figure(figsize=(8, 6))
plt.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)

# Write each count into its cell; cells above half the maximum count get
# white text, the others get black text
thresh = conf_matrix.max() / 2.
n_true, n_pred = conf_matrix.shape
for i in range(n_true):
    for j in range(n_pred):
        count = conf_matrix[i, j]
        text_color = "white" if count > thresh else "black"
        plt.text(j, i, f"{count:d}", horizontalalignment="center", color=text_color)

plt.tight_layout()
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()
```

### Step 7: Analysis of this task and results
- ```Task 1:``` Provide a detailed interpretation of the results (write it as part of Task 2, i.e., there is no need to write it separately). (Hint: If accuracy is high but other metrics are low, you have to explain in detail what that means.)
- ```Task 2:``` Write a detailed analysis about this assignment, from nature of data to the final results. Write as if you are writing a small paper for publication: Introduction (Problem statement), about the data and study area, methodology, results and discussion.