# 0. Import Required Libraries

In [None]:
%matplotlib inline

import mlflow
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt

import math
from lib.utils import *

# 1. Define Classifier Architecture

In [None]:
class Classifier(nn.Module):
    """Fully connected two-class classifier over flattened image chunks.

    Three hidden linear layers, each followed by ReLU and dropout (p=0.5),
    feed a final linear layer whose output is turned into log-probabilities
    over two classes.

    Args:
        channel: Number of channels in each input chunk.
        in_len: Side length of each (square) input chunk.
    """

    def __init__(self, channel, in_len):
        super().__init__()

        # Number of features once a (channel, in_len, in_len) chunk is flattened.
        flat_features = channel * in_len ** 2

        self.fc1_size = flat_features
        self.fc2_size = flat_features * 3
        self.fc3_size = self.fc2_size
        self.fc4_size = flat_features

        self.fc1 = nn.Linear(self.fc1_size, self.fc2_size)
        self.fc2 = nn.Linear(self.fc2_size, self.fc3_size)
        self.fc3 = nn.Linear(self.fc3_size, self.fc4_size)
        self.fc4 = nn.Linear(self.fc4_size, 2)

        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 2) for ``x``."""
        # Collapse everything after the batch dimension into one feature axis.
        hidden = x.view(x.shape[0], -1)

        # Each hidden layer: linear -> ReLU -> dropout.
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.dropout(F.relu(layer(hidden)))

        return F.log_softmax(self.fc4(hidden), dim=1)

    def train_network(self, trainloader, val_loader, epochs=20):
        """Placeholder; not implemented."""

    def test(self):
        """Placeholder; not implemented."""

# 2. Start MLflow Run

In [None]:
# Store MLflow runs in a local "mlruns" directory relative to the working
# directory. A forward slash is used because "\m" is not a valid escape
# sequence (newer Python versions warn about it) and a backslash is not a
# path separator on POSIX systems.
mlflow.set_tracking_uri("file:./mlruns")
mlflow.start_run()

# Collected while the notebook runs and logged to MLflow in the final cell.
params = {}     # hyperparameters: name -> value
artifacts = []  # names of files saved under ./artifacts/
metrics = {}    # summary metrics: name -> value

# 3. Load Data

In [None]:
# Load the samples and their file names from ./data/modis.
# NOTE(review): load_data presumably comes from lib.utils (star import); the
# meaning of the first argument (10) is not visible here — confirm there.
data, filenames = load_data(10, "./data/modis")

In [None]:
# For every sample, the first element is the label and the rest is the model input.
labels = [sample[0] for sample in data]
train_data = [sample[1:] for sample in data]

In [None]:
class LandsatDataLoader():
    """Yield mini-batches of image chunks, one source image at a time.

    Each image is chunked with ``chunk_image(merge_dims(image))`` and its
    label with ``chunk_image(label, label=True)``. The paired chunks are then
    batched by a ``torch.utils.data.DataLoader``.

    Args:
        data: Sequence of images.
        ground_truth: Sequence of labels; ``ground_truth[i]`` belongs to
            ``data[i]``.
        batch_size: Maximum number of chunks per yielded batch.
        shuffle: If true, images are visited in a random order on every pass
            and the chunks of each image are shuffled as well.
    """

    def __init__(self, data, ground_truth, batch_size, shuffle=True):
        self.data = data
        # Use the labels that were passed in. Reading a global ``labels``
        # instead would pair a subset of images (e.g. a validation split)
        # with the wrong labels.
        self.labels = ground_truth
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Total number of batches per pass; None until it has been computed.
        self.len = None

    def _image_loader(self, index):
        """Build the DataLoader over the chunks of image ``index`` and its label."""
        chunked_data = chunk_image(merge_dims(self.data[index]))
        chunked_labels = chunk_image(self.labels[index], label=True)

        dataset = list(zip(chunked_data, chunked_labels))
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=self.shuffle)

    def __len__(self):
        """Return the number of batches produced by one full pass."""
        if self.len is None:
            # Not iterated yet: count the batches per image. The count does not
            # depend on shuffling, so it can be cached.
            self.len = sum(len(self._image_loader(i)) for i in range(len(self.data)))
        return self.len

    def __iter__(self):
        """Yield ``(batch, ground_truth)`` pairs covering every image once."""
        import random

        # Shuffle the visiting order rather than ``self.data`` itself so that
        # every image stays aligned with its label and the caller's list is
        # left untouched.
        order = list(range(len(self.data)))
        if self.shuffle:
            random.shuffle(order)

        batch_count = 0
        for index in order:
            dataloader = self._image_loader(index)
            batch_count += len(dataloader)

            for batch, ground_truth in dataloader:
                yield batch, ground_truth

        # Only reached after a complete pass.
        self.len = batch_count

# 4. Chunk Images
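Each image is broken up into a tensor of shape b × 9 × 3 × 3 (batch, channel, height, width), where b is the batch dimension, i.e. the number of chunks taken from that image.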

In [None]:
%%time

# Report the chunked shape of every image in train_data and the total chunk count.
print("{:30} shape: (batch, channel, height, width)".format("filename"))

# Running total of the first dimension (number of chunks) over all images.
chunk_sum = 0

for i, image in enumerate(train_data):
    chunked_image = chunk_image(merge_dims(image))
    
    chunk_sum += chunked_image.shape[0]
        
    print("{:30} shape: {}".format(filenames[i], chunked_image.shape))
    
# Uses the shape of the last image processed; its final dimension is printed
# for both sides of the "N x N" chunk size.
print("\nTotal {} x {} chunks: {}".format(chunked_image.shape[-1], chunked_image.shape[-1], chunk_sum))

In [None]:
# Number of chunks per mini-batch; stored in params so it is logged with the MLflow run.
batch_size = 2048
params["batch_size"] = batch_size

In [None]:
# The first 12 images are used for training, the remaining ones for validation.
# NOTE(review): both loaders keep the default shuffle=True, so validation
# batches are shuffled as well.
trainloader = LandsatDataLoader(train_data[:12], labels[:12], batch_size=batch_size)
val_loader = LandsatDataLoader(train_data[12:], labels[12:], batch_size=batch_size)

# 5. Instantiate Model and Optimizer

In [None]:
# 9 channels and a side length of 3, i.e. 9 * 3 * 3 = 81 flattened input features.
model = Classifier(9, 3)

In [None]:
#model = parallelize(model)

In [None]:
import torch.optim as optim

# Adam over all model parameters with a learning rate of 0.001.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Negative log-likelihood loss; pairs with the log_softmax output of Classifier.forward.
criterion = nn.NLLLoss()

# 6. Use GPU, if available

In [None]:
# Prefer the GPU when CUDA is available; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move the model to the selected device.
model.to(device)

# 7. Train and Validate Model

In [None]:
import os

# Train for a fixed number of epochs. After every epoch, evaluate on the
# validation loader, log the averaged metrics to MLflow and checkpoint the
# model whenever the average validation loss improves.
epochs = 30
params["epochs"] = epochs

train_losses = []
val_losses = []
val_accuracies = []

# Lowest average validation loss seen so far.
min_val_loss = float("inf")

# torch.save does not create missing directories.
os.makedirs("./artifacts", exist_ok=True)

for epoch in range(epochs):
    # ============================================
    #            TRAINING
    # ============================================
    model.train()
    train_loss = 0

    for batch, ground_truth in trainloader:
        batch, ground_truth = batch.to(device), ground_truth.to(device)
        # Clear gradients in optimizer
        optimizer.zero_grad()
        # Call the module itself (not .forward) so registered hooks run.
        output = model(batch.float())
        # The output is already (batch, 2); squeezing it would drop the batch
        # dimension for a batch of size 1 and break the loss computation.
        loss = criterion(output, ground_truth.long())
        train_loss += loss.item()
        # Backpropagation
        loss.backward()
        # Update weights
        optimizer.step()

    # ============================================
    #            VALIDATION
    # ============================================
    model.eval()
    val_loss = 0

    y_pred = np.array([])
    y_true = np.array([])

    with torch.no_grad():
        for batch, ground_truth in val_loader:
            batch, ground_truth = batch.to(device), ground_truth.to(device)
            # forward pass
            log_probs = model(batch.float())
            probs = torch.exp(log_probs)

            # Predicted class = index of the highest probability.
            top_p, top_class = probs.topk(1, dim=1)
            y_pred = np.append(y_pred, cuda_to_numpy(top_class))
            y_true = np.append(y_true, cuda_to_numpy(ground_truth))

            # calculate loss
            loss = criterion(log_probs, ground_truth.long())
            val_loss += loss.item()

    # Epoch summary: average the summed batch losses over the number of batches.
    t_loss_avg = train_loss / len(trainloader)
    v_loss_avg = val_loss / len(val_loader)
    accuracy = accuracy_score(y_true, y_pred)

    if v_loss_avg < min_val_loss:
        # Remember the new best so later, worse epochs do not overwrite the checkpoint.
        min_val_loss = v_loss_avg
        torch.save(model.state_dict(), "./artifacts/model.pth")
        # Register the checkpoint once, however many times it is rewritten.
        if "model.pth" not in artifacts:
            artifacts.append("model.pth")

    # Pass the epoch as the step so MLflow records one point per epoch.
    mlflow.log_metric("train_loss", t_loss_avg, step=epoch)
    mlflow.log_metric("val_loss", v_loss_avg, step=epoch)
    mlflow.log_metric("validation_accuracy", accuracy, step=epoch)

    train_losses.append(t_loss_avg)
    val_losses.append(v_loss_avg)
    val_accuracies.append(accuracy)

    print('Epoch [{:5d}/{:5d}] | train loss: {:8.6f} | validation loss: {:8.6f} | validation accuracy: {:6.4f}%'.format(
                epoch+1, epochs, t_loss_avg, v_loss_avg, accuracy * 100))

# 8. Plot Learning Curve

In [None]:
# Plot the per-epoch average training and validation loss on the same axes.
plt.plot(train_losses, label="Training")
plt.plot(val_losses, label="Validation")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Learning Curve for MODIS Image Classifier")
plt.legend()

# Save the figure under ./artifacts/ and record its name so the final cell
# logs it as an MLflow artifact.
figure_name = "train_loss.png"
plt.savefig("./artifacts/" + figure_name)
artifacts.append(figure_name)

In [None]:
# The stored accuracies are fractions in [0, 1] (the epoch printout multiplies
# them by 100); scale them so the values match the "Accuracy(%)" axis label.
plt.plot([acc * 100 for acc in val_accuracies])
plt.xlabel("Epochs")
plt.ylabel("Accuracy(%)")
plt.title("Validation Accuracy for MODIS Image Classifier")

# Save the figure under ./artifacts/ and record its name so the final cell
# logs it as an MLflow artifact.
figure_name = "val_accuracy.png"
plt.savefig("./artifacts/" + figure_name)
artifacts.append(figure_name)

# 9. Wrap Up MLflow Run

In [None]:
# Log everything collected during the notebook to the active MLflow run.
for name, val in params.items():
    mlflow.log_param(name, val)

for name, val in metrics.items():
    mlflow.log_metric(name, val)

artifact_path = "./artifacts/"
# The same file name can be appended to `artifacts` more than once (the
# checkpoint is registered each time it is saved); log each file only once,
# keeping first-seen order.
logged_names = set()
for name in artifacts:
    if name in logged_names:
        continue
    logged_names.add(name)
    mlflow.log_artifact(artifact_path + name)

mlflow.end_run()