# "Standard Candle" 11: Start Getting Real

Combine the work in experiment 10 with the older work in experiment 6 to have a single solid start, then bring in the basic dataloader to start getting real SDO images.

Note that this notebook should be run from the `notebooks/` subdirectory.

In [None]:
# NOTE: Set this to the notebook's name for each experiment so that training
# results are saved into the right sub-directory (it is joined into
# `results_path` further down).
notebook_name = '11_getting_real'

from collections import namedtuple, defaultdict
import math
import random
import os
import shutil
import pdb
from functools import reduce
import operator

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

import pandas

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image

In [None]:
# Path to data and training results.
# NOTE(review): the trailing comment below names notebooks/01b_simple_toy, which
# looks stale for this notebook -- confirm what root_path is relative to.
root_path = '../..' # Relative to: notebooks/01b_simple_toy
results_path = os.path.join(root_path, 'training_results', notebook_name)
# Checkpoint locations; currently only referenced by the commented-out
# torch.save() calls inside train().
model_path = os.path.join(results_path, 'model.pth')
optimizer_path = os.path.join(results_path, 'optimizer.pth')

# Create the results directory if it is missing.
for path in [results_path]:
  if not os.path.exists(path):
    print('{} does not exist; creating directory...'.format(os.path.abspath(path)))
    os.makedirs(path)

# Training hyperparameters.
num_epochs = 5
batch_size_train = 64
batch_size_test = 1000
# A training progress line is printed every `log_interval` batches.
log_interval = 10
# Geometry of the synthetic images: num_channels x height x width.
height = 28
width = 28
num_channels = 7

In [None]:
def init_gpu(cuda_device=0):
  """ Enable cuDNN and return the torch.device for the requested GPU.

  Args:
    cuda_device: index of the CUDA device to build the device handle for.

  Returns:
    A `torch.device` of the form "cuda:<cuda_device>".

  Raises:
    RuntimeError: if CUDA is not available.

  Note that this only builds a device handle; it does not change the
  process-wide current CUDA device, which is why the log line reports the
  selected and the current device separately.
  """
  torch.backends.cudnn.enabled = True
  cuda_ready = torch.cuda.is_available()
  if not cuda_ready:
    raise RuntimeError("CUDA not available! Unable to continue")

  # Training is pinned to exactly one GPU, addressed by index.
  selected = torch.device("cuda:{}".format(cuda_device))
  status = "Using device {} for training, current device: {}, total devices: {}".format(
    selected, torch.cuda.current_device(), torch.cuda.device_count())
  print(status)
  return selected

def set_seed(random_seed=1, deterministic_cuda=True):
  """ Force runs to be deterministic and reproducible.

  Seeds Python's `random`, NumPy and PyTorch (CPU and every CUDA device).

  Args:
    random_seed: integer seed applied to all random number generators.
    deterministic_cuda: when True, also force cuDNN into deterministic mode
      and disable its auto-tuner (this can cost some performance).
  """
  random.seed(random_seed)
  np.random.seed(random_seed)
  torch.manual_seed(random_seed)
  # Seed every GPU rather than only the current one: training may run on a
  # device other than the process's current CUDA device. This call is silently
  # ignored when CUDA is unavailable.
  torch.cuda.manual_seed_all(random_seed)
  # NOTE: hash randomization of the running interpreter is fixed at start-up,
  # so this only affects Python processes launched after this point.
  os.environ['PYTHONHASHSEED'] = str(random_seed)

  # Note: this can have a performance hit.
  if deterministic_cuda:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
# Build the device handle for GPU index 1 and seed the RNGs before any model or
# data is created. NOTE(review): index 1 presumably requires at least two
# visible GPUs -- confirm for the target machine.
device = init_gpu(cuda_device=1)
set_seed()

In [None]:
class Net(nn.Module):
    """ Small CNN that regresses one value in (0, 1) per output dimension.

    Two 3x3 convolutions (64 then 128 filters), each followed by a 3x3
    max-pool, feed a two-layer fully connected head whose output is squashed
    by a sigmoid.
    """

    def __init__(self, input_shape, output_dim):
        """ Build the layers for inputs of shape `input_shape`.

        Args:
            input_shape: sequence of three ints, (channels, height, width).
            output_dim: number of values the network predicts per sample.

        Raises:
            ValueError: if `input_shape` does not have exactly three entries.
        """
        super().__init__()
        if len(input_shape) != 3:
            raise ValueError('Expecting an input_shape representing dimensions CxHxW')

        channels = input_shape[0]
        self._input_channels = channels
        print('input_channels: {}'.format(channels))

        self._conv2d1 = nn.Conv2d(in_channels=channels, out_channels=64, kernel_size=3)
        self._conv2d2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)

        # Push a dummy single-sample batch through the convolutional stack to
        # discover how many features the first linear layer has to accept.
        probe = torch.zeros(input_shape).unsqueeze(0)
        self._cnn_output_dim = self._cnn(probe).numel()
        print('cnn_output_dim: {}'.format(self._cnn_output_dim))

        self._fc1 = nn.Linear(self._cnn_output_dim, 256)
        self._fc2 = nn.Linear(256, output_dim)

    def _cnn(self, x):
        """ Convolutional feature extractor: conv -> relu -> pool -> conv -> pool. """
        first_stage = nn.functional.max_pool2d(torch.relu(self._conv2d1(x)), kernel_size=3)
        # No non-linearity is applied between the second convolution and its pool.
        return nn.functional.max_pool2d(self._conv2d2(first_stage), kernel_size=3)

    def forward(self, x):
        """ Map a batch of images to `output_dim` sigmoid activations per sample. """
        features = self._cnn(x).view(x.shape[0], -1)
        hidden = torch.relu(self._fc1(features))
        return torch.sigmoid(self._fc2(hidden))

In [None]:
class SyntheticSDODataset(torch.utils.data.Dataset):
  """ Synthetic multi-channel "suns" with a random dimming factor per channel.

  Every item is generated on the fly from the NumPy and torch global RNGs, so
  the index passed to `__getitem__` is ignored and repeated lookups return
  different samples. All returned tensors live on the module-level `device`.
  """

  def __init__(self, num_channels, height, width, size, length):
    """
    Args:
      num_channels: number of channels per image.
      height: image height in pixels.
      width: image width in pixels.
      size: side length of the square grid the sun is drawn on; the drawn
        image is reshaped to (num_channels, height, width), so size * size
        must equal height * width.
      length: value reported by `len()`, i.e. samples per epoch.
    """
    self._num_channels = num_channels
    self._height = height
    self._width = width
    self._size = size
    self._length = length

    # TODO: Compute mean and std across dataset, and normalize them.

  def __getitem__(self, idx):
    """ Return `(dimmed_sun, dim_factor, sun)` for a freshly drawn sun.

    `dim_factor` holds one uniform [0, 1) factor per channel and `dimmed_sun`
    is `sun` with every channel scaled by its factor.
    """
    sun = self._draw_sun(self._num_channels, self._size, self._height, self._width)
    # Use the dataset's own channel count rather than a module-level global so
    # the dataset works for any configuration it was constructed with.
    dim_factor = torch.rand(self._num_channels).to(device)
    # Broadcast the per-channel factors over height and width; this yields a
    # new tensor and leaves the undimmed `sun` untouched.
    dimmed_sun = sun * dim_factor.view(-1, 1, 1)

    return dimmed_sun, dim_factor, sun

  def __len__(self):
    return self._length

  def _draw_sun(self, num_channels, size, height, width):
    """ Draw a synthetic sun.

    The sun is a Gaussian blob of random radius R centred on a size x size
    grid; channel c is scaled by the intensity R ** (c / 10).
    """
    coords = np.arange(size) - (size - 1) / 2.
    xx, yy = np.meshgrid(coords, coords)
    r_squared = xx * xx + yy * yy
    R = np.random.rand(1) * (size / 3)
    profile = np.exp(-r_squared / (R * R))
    channels = [profile * R ** (c / 10.) for c in range(num_channels)]

    # Stack on the host and transfer once instead of once per channel.
    image = torch.from_numpy(np.stack(channels)).float().to(device)
    return image.view(num_channels, height, width)

In [None]:
# Training set: 10000 synthetic suns per epoch, drawn on a height x height grid.
train_dataset = SyntheticSDODataset(num_channels, height, width, size=height, length=10000)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size_train,
                                           shuffle=True)

# Test set: 1000 synthetic suns, which with batch_size_test = 1000 is a single batch.
test_dataset = SyntheticSDODataset(num_channels, height, width, size=height, length=1000)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size_test,
                                           shuffle=True)

In [None]:
def show_sample_image(loader):
  """ Plot each channel of the undimmed sun returned by `loader.dataset[0]`.

  All channels share one colour scale, spanning the minimum to the maximum
  value of the whole image, so their brightness can be compared directly.
  """
  print('\nUndimmed channels for a single original sun:\n')
  _dimmed, _factors, orig_sun = loader.dataset[0]
  sun_numpy = orig_sun.cpu().numpy()
  for channel in sun_numpy:
    colour_scale = {'vmin': sun_numpy.min(), 'vmax': sun_numpy.max()}
    plt.imshow(channel, norm=None, cmap='hot', **colour_scale)
    plt.show()

def print_details(orig_data, output, dimmed_data, dim_factors, train):
  """ Visualise the first sample of a batch, channel by channel.

  For every channel three panels are drawn: the original image, the dimmed
  image, and the dimmed image divided by the predicted dimming factor. A final
  line plot compares true and predicted dimming factors for that sample.

  Args:
    orig_data: batch of undimmed images.
    output: batch of predicted dimming factors, indexed [sample, channel].
    dimmed_data: batch of dimmed images.
    dim_factors: batch of true dimming factors, indexed [sample, channel].
    train: selects the title of the final plot (training vs. testing).
  """
  print('\n\nDetails with sample from final batch:')
  # One colour scale for all panels, taken from the whole undimmed batch.
  colour_scale = dict(vmin=torch.min(orig_data), vmax=torch.max(orig_data))
  first_orig = orig_data[0].cpu().numpy()
  first_dimmed = dimmed_data[0].cpu().numpy()

  for idx, (orig_channel, dimmed_channel) in enumerate(zip(first_orig, first_dimmed)):
    fig = plt.figure()
    predicted = output[0, idx]
    panels = (orig_channel, dimmed_channel, dimmed_channel / float(predicted))
    for column, panel in enumerate(panels, start=1):
      axis = fig.add_subplot(1, 3, column)
      axis.imshow(panel, norm=None, cmap='hot', **colour_scale)
    print('\nChannel: {} (left: original, middle: dimmed, right: undimmed)\nDimming (true): {}, dimming (predicted): {}'.format(
      idx, dim_factors[0, idx], predicted))
    plt.show()

  true_curve = dim_factors[0].view(-1).cpu().numpy()
  plt.plot(true_curve, label='Dimming factors (true)')
  predicted_curve = output[0].detach().view(-1).cpu().numpy()
  plt.plot(predicted_curve, label='Dimming factors (predicted)')
  if train:
    plt.title('training dimming factors')
  else:
    plt.title('testing dimming factors')
  plt.legend()
  plt.show()

def generate_train_accuracy_stats(correct, output, targets, num_column_labels, num_subsample):
  """ Score one batch of predictions and build a small table for display.

  A channel prediction counts as correct when it is within 15% of its ground
  truth value; a sample counts as fully correct when all of its channels are.

  Args:
    correct: running count of fully correct samples from earlier batches.
    output: tensor of predictions, indexed [sample, channel].
    targets: tensor of ground truth values with the same layout as `output`.
    num_column_labels: number of columns in the result table (at least 4).
    num_subsample: number of table rows to sample (with replacement) for display.

  Returns:
    A tuple `(correct, pretty_results)` where `correct` is the updated running
    count and `pretty_results` is a float32 array of shape
    (num_subsample, num_column_labels) whose first four columns are: mean
    prediction, mean ground truth, mean absolute difference, and percentage of
    channels correct for the sampled rows.
  """
  # TODO: For efficiency reasons, convert all of this to torch rather than numpy operations
  # so that we can do all this work on the GPU.
  preds = output.detach().cpu().numpy()
  targets = targets.detach().cpu().numpy()
  # Size everything from the batch itself rather than from module-level
  # settings, so a short final batch or a different configuration still works.
  num_rows = len(preds)

  # If a channel brightness prediction is within this percentage of the ground truth then we
  # consider that prediction correct.
  pct_close = 0.15

  per_channel_diff = np.abs(targets - preds)
  per_channel_correct = per_channel_diff <= np.abs(pct_close * targets)

  # Fraction of channels predicted correctly for each sample; the mean of a
  # boolean row is its count of True values divided by the number of channels.
  pct_correct_per_row = per_channel_correct.mean(axis=1)

  # Which batch results have _all_ of their channel predictions fully correct?
  correct += int(per_channel_correct.all(axis=1).sum())

  pretty_results = np.zeros((num_rows, num_column_labels), dtype=np.float32)

  # The mean channel prediction across each row of the batch results.
  pretty_results[:, 0] = np.round(preds.mean(axis=1), decimals=2)

  # The mean channel ground truth across each row of the batch results.
  pretty_results[:, 1] = np.round(targets.mean(axis=1), decimals=2)

  # The mean difference between prediction and ground truth across each row of the batch results.
  pretty_results[:, 2] = np.round(per_channel_diff.mean(axis=1), decimals=2)

  # Percentage correct across all the channels for a given batch row.
  pretty_results[:, 3] = np.round(pct_correct_per_row * 100.0, decimals=0)

  # Randomly sub-sample some of the results since 100s or 1000s are too much to display.
  lookup_idxs = np.random.choice(num_rows, size=(num_subsample,))
  return correct, pretty_results[lookup_idxs]

In [None]:
# Visual sanity check of the synthetic data before training.
show_sample_image(train_loader)

# The network predicts one dimming factor per input channel.
model = Net(input_shape=[num_channels, height, width], output_dim=num_channels)
model.cuda(device)
# Adam with its default hyperparameters.
optimizer = torch.optim.Adam(model.parameters())
  
def train(epoch):
  """ Train the module-level `model` for one pass over `train_loader`.

  Logs the loss every `log_interval` batches, shows debug plots for the final
  batch, and plots the per-batch loss curve.

  Args:
    epoch: epoch number, used only for log output.

  Returns:
    The mean MSE loss over all batches of the epoch.
  """
  print("\n\n===================================\n\n")
  print("\n\nTraining epoch {}\n".format(epoch))
  model.train()
  criterion = nn.MSELoss()
  losses = []
  for batch_idx, (dimmed_data, dim_factors, orig_data) in enumerate(train_loader):
    # Mirror test(): put the batch on the training device explicitly instead of
    # relying on the dataset to hand out tensors that are already there. This is
    # a no-op when they are.
    dimmed_data = dimmed_data.to(device)
    dim_factors = dim_factors.to(device)

    optimizer.zero_grad()
    output = model(dimmed_data)
    loss = criterion(output, dim_factors)
    loss.backward()
    optimizer.step()
    loss_value = float(loss)
    losses.append(loss_value)

    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(orig_data), len(train_loader.dataset),
        100.0 * (batch_idx / len(train_loader)), loss_value))
      # TODO: Save the model.
      # torch.save(model.state_dict(), model_path)
      # torch.save(optimizer.state_dict(), optimizer_path)

  # Print extra debug output on the final batch.
  print_details(orig_data, output, dimmed_data, dim_factors, train=True)

  plt.plot(losses, label='training loss')
  plt.title('training loss')
  plt.show()

  mean_loss = np.mean(losses)
  print('\nAt end of train epoch {}, loss min: {}, max: {}, mean: {}'.format(epoch,
    min(losses), max(losses), mean_loss))

  return mean_loss

def test(epoch):
  """ Evaluate the module-level `model` on `test_loader` without gradients.

  Prints a random sample of per-row statistics, shows debug plots for the
  final batch, and reports how many samples had every channel predicted
  within tolerance.

  Args:
    epoch: epoch number, used only for log output.

  Returns:
    The mean MSE loss over all test batches.
  """
  print("\n\nTesting epoch {}".format(epoch))
  with torch.no_grad():
    model.eval()
    criterion = nn.MSELoss()
    losses = []
    correct = 0
    num_subsample = 3 # Per batch, how many results to subsample to print out for debugging.
    column_labels = ['Pred', 'GT', 'Delta', 'Pct Correct']

    # One group of `num_subsample` rows per test batch.
    num_batches = math.ceil(len(test_loader.dataset) / batch_size_test)
    pretty_results = np.zeros((int(num_batches * num_subsample), len(column_labels)),
                              dtype=np.float32)

    for batch_idx, (dimmed_data, dim_factors, orig_data) in enumerate(test_loader):
      dimmed_data = dimmed_data.to(device)
      dim_factors = dim_factors.to(device)
      output = model(dimmed_data).to(device)
      losses.append(float(criterion(output, dim_factors)))

      correct, pretty_print_subset = generate_train_accuracy_stats(
        correct, output, dim_factors, len(column_labels), num_subsample)
      # Each batch fills its own contiguous group of rows in the table.
      first_row = batch_idx * num_subsample
      pretty_results[first_row:first_row + num_subsample] = pretty_print_subset

    # Print extra debug output on the final batch.
    print_details(orig_data, output, dimmed_data, dim_factors, train=False)

    print("\n\nRandom sample of mean predictions across channels for test set, "
          "where each row is a test sample in the training batch:\n")
    print(pandas.DataFrame(pretty_results, columns=column_labels).to_string(index=False))

    num_samples = len(test_loader.dataset)
    print('\n\nEpoch {}, test set: avg. loss: {:.8f}, Accuracy all channels correct: {}/{} ({:.0f}%)'.format(
          epoch, np.mean(losses), correct, num_samples,
          100.0 * (correct / num_samples)))

    return np.mean(losses)

In [None]:
# Mean loss of every epoch so far; one entry is appended per epoch.
train_losses = []
test_losses = []
for epoch in range(1, num_epochs + 1):
  # Accumulate the history instead of overwriting the lists with a scalar, so
  # the plot below shows the test loss across epochs rather than one point.
  train_losses.append(train(epoch))
  test_losses.append(test(epoch))

  plt.plot(test_losses, label='testing loss')
  plt.title('testing loss')
  plt.show()

# Report the mean loss of the final epoch; skipped when no epoch was run.
if train_losses:
  print('\n\nFinal mean training loss after {} epochs: {}'.format(
    num_epochs, train_losses[-1]))
  print('Final mean testing loss after {} epochs: {}'.format(
    num_epochs, test_losses[-1]))