In [None]:
import numpy as np
import tensorflow as tf
import pickle
import glob
import pandas as pd
import os 
import sys

sys.path.append('../model')
import likelihood as lh
import ddm
from tqdm.notebook import tqdm 

## Training

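We train one model per week, for weeks 11 to 52, using each location's images from the previous ten weeks as the basis. Each week gets a separate model per band (r, g, b), fitted on a random sample of locations. We train only on non-anomalous images, so we first remove all known anomalies (locations labelled positive in the truth table) from the training set.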

In [None]:
# Load the image dictionary, indexed as periods[location][f'week_{n}'][band].
with open('../../time_periods.p', 'rb') as periods_file:
    periods = pickle.load(periods_file)

In [None]:
# Remove every location labelled positive so training only sees negatives.
# NOTE: this mutates `periods` in place; the dict is reloaded from disk
# before prediction so that all locations get scored.
truth_df = pd.read_csv('../../truth_cleaned.csv')
is_positive = truth_df['label'] == True  # noqa: E712 -- explicit compare also tolerates NaN labels
positive_locs = truth_df.loc[is_positive, 'location_name'].to_numpy()

for positive_loc in positive_locs:
    periods.pop(positive_loc, None)

In [None]:
def get_band_images(band, inds, week, basis_length):
    """Get training and test images from the module-level ``periods`` dict.

    For each selected location, collects the ``basis_length`` images in
    ``band`` for the weeks immediately preceding ``week`` (the basis), plus
    the image at ``week`` itself (the test image).

    Args:
        band (str): band key, one of 'r', 'g', or 'b'.
        inds (array-like of int): positional indices into ``list(periods)``
            selecting which locations to use.
        week (int): week of the last (test) image.
        basis_length (int): number of previous weeks to get.

    Returns:
        basis (tf.Tensor): float tensor of shape
            (len(inds), basis_length, dim_x, dim_y, 1).
        test (tf.Tensor): float tensor of shape
            (len(inds), 1, dim_x, dim_y, 1).

    Raises:
        ValueError: if the stacked images are not 2-D per band
            (the shape unpacking below acts as a rank check).
    """
    keys = list(periods)
    locs = [keys[i] for i in inds]
    basis_weeks = range(week - basis_length, week)

    basis = np.array([
        [periods[loc][f'week_{w}'][band] for w in basis_weeks]
        for loc in locs
    ])
    test = np.array([periods[loc][f"week_{week}"][band] for loc in locs])

    basis = tf.convert_to_tensor(basis, dtype=float)
    n_locs, n_weeks, dim_x, dim_y = basis.shape
    # Trailing singleton channel axis: (N, T, X, Y) -> (N, T, X, Y, 1).
    basis = tf.reshape(basis, [n_locs, n_weeks, dim_x, dim_y, 1])

    test = tf.convert_to_tensor(test, dtype=float)
    n_locs, dim_x, dim_y = test.shape
    # Singleton time and channel axes: (N, X, Y) -> (N, 1, X, Y, 1).
    test = tf.reshape(test, [n_locs, 1, dim_x, dim_y, 1])

    return basis, test

In [None]:
# Train one model per week and band, caching each week's models on disk.

basis_length = 10
weeks = range(basis_length + 1, 53)

n_train_samples = 500  # Number of images to use for training
n_locations = len(periods)
# Sampling without replacement raises if fewer locations remain (after the
# positives were removed) than requested, so cap the sample size.
n_draw = min(n_train_samples, n_locations)

SEED = 0  # combined with the week number so every week's sample is reproducible
model_dir = '../../models'
os.makedirs(model_dir, exist_ok=True)

for w in tqdm(weeks):
    model_path = os.path.join(model_dir, f'model_week_{w}')

    # skip if model already exists
    if os.path.exists(model_path):
        continue

    # Model for this week
    w_model = {}

    # Per-week RNG: week w always gets the same sample, regardless of
    # which other weeks were skipped because they were already cached.
    rng = np.random.default_rng([SEED, w])
    inds = rng.choice(n_locations, n_draw, replace=False)

    for band in ['r', 'g', 'b']:
        X, y = get_band_images(band, inds, w, basis_length)
        w_model[band] = ddm.fit_observation(
            X, y, num_steps=2000, learning_rate=0.001,
            reg=0.01, normalization='none'
        )

    # Write to a temporary file, then rename, so an interrupted run never
    # leaves a truncated file that the existence check above treats as done.
    tmp_path = model_path + '.partial'
    with open(tmp_path, 'wb') as f:
        pickle.dump(w_model, f)
    os.replace(tmp_path, model_path)
        


## Predict

In [None]:
# Reload from disk: the training section popped the positive locations out
# of `periods`, but prediction must score every location.
with open('../../time_periods.p', 'rb') as periods_file:
    periods = pickle.load(periods_file)

In [None]:
# Load all trained models, keyed by integer week number.

models = {}
model_dir = '../../models'
model_prefix = 'model_week_'

for fname in sorted(os.listdir(model_dir)):
    week_str = fname[len(model_prefix):]
    # Skip anything that is not a finished model file (e.g. .ipynb_checkpoints,
    # .DS_Store, or '*.partial' leftovers), which would otherwise crash int().
    if not fname.startswith(model_prefix) or not week_str.isdecimal():
        continue
    with open(os.path.join(model_dir, fname), 'rb') as f:
        models[int(week_str)] = pickle.load(f)

assert models, f'no model files found in {model_dir}'


In [None]:
def get_loc_images(band, loc, week, basis_length):
    """Build the (basis, test) model inputs for a single location.

    Reads from the module-level ``periods`` dict.

    Args:
        band (str): band key, e.g. 'r', 'g' or 'b'.
        loc: key of the location in ``periods``.
        week (int): week of the test image.
        basis_length (int): number of weeks before ``week`` forming the basis.

    Returns:
        basis (tf.Tensor): shape (1, basis_length, dim_x, dim_y, 1).
        test (tf.Tensor): shape (1, 1, dim_x, dim_y, 1).
    """
    loc_weeks = periods[loc]

    history = np.array([
        loc_weeks[f'week_{w}'][band]
        for w in range(week - basis_length, week)
    ])
    target = np.array(loc_weeks[f"week_{week}"][band])

    history = tf.convert_to_tensor(history, dtype=float)
    dims = history.shape
    # Leading batch axis of 1 and trailing channel axis of 1.
    basis = tf.reshape(history, [1, dims[0], dims[1], dims[2], 1])

    target = tf.convert_to_tensor(target, dtype=float)
    # Batch, time and channel axes of 1 around the 2-D image.
    test = tf.reshape(target, [1, 1, target.shape[0], target.shape[1], 1])

    return basis, test

In [None]:
# Score every location for each trained week and band.
# results[loc][week][band] holds the first element of ddm.hot_detect's output.
results = {}

# Use the same weeks the models were trained for, rather than a hard-coded
# range that silently diverges if `basis_length` changes.
missing_weeks = [w for w in weeks if w not in models]
assert not missing_weeks, f'no model loaded for weeks {missing_weeks}'

for loc in tqdm(periods.keys()):

    results[loc] = {}

    for week in weeks:

        model_w = models[week]
        results[loc][week] = {}

        for band in ['r', 'g', 'b']:

            model = model_w[band]
            X, y = get_loc_images(band, loc, week, basis_length)
            hot_score = ddm.hot_detect(model['gamma'], basis=X, test=y, rmse=model['rmse'],
                                       normalization='none', mean=model['mean'], std=model['std'],
                                       reduce=True)

            results[loc][week][band] = hot_score.numpy()[0]
    

In [None]:
# Persist the nested score dict: results[loc][week][band].
with open('../../ddm_results.p', 'wb') as results_file:
    pickle.dump(results, results_file)