In [None]:
import glob
import os
import pickle
import sys

import numpy as np
import pandas as pd
import rasterio
import torch
import tensorflow as tf
from tqdm.notebook import tqdm

import likelihood as lh

%load_ext autoreload
%autoreload 2

# sys.path.insert takes (index, path): prepend the project's models directory
# so the local UNet definition is importable.
sys.path.insert(0, '../models/')
from unet_model import UNet


In [None]:
# Load the trained UNet used for inference below. The inference cell feeds it
# 4-channel images and reads a 2-channel output, matching UNet(4, 2).
# Earlier checkpoint, kept for reference:
#   model = UNet(3, 2)
#   '~/datadrive/data/archive/saved_models/finished/model8_10_ia_data.pth'
#   (torch.load does not expand '~'; wrap such paths in os.path.expanduser)
model = UNet(4, 2, full_size=True)
# map_location='cpu': inference runs on CPU (outputs are converted with .numpy()),
# so a checkpoint saved from a GPU must be remapped to load here.
state_dict = torch.load('../../datadrive/mixed_20epochs_11-12.pth', map_location='cpu')
model.load_state_dict(state_dict)
# Put any dropout / batch-norm layers into inference behaviour.
model.eval();

In [None]:
# exist_ok makes this cell safe to re-run (plain os.makedirs raises FileExistsError).
os.makedirs('inference_results', exist_ok=True)

In [None]:
# Sanity check: is the directory the inference loop writes into present?
inference_dir_exists = os.path.exists('inference_results')
inference_dir_exists

In [None]:
# Ground-truth table, plus the per-location image directory the inference loop globs.
batched_root = "../../datadrive/data/processed/in_2019_batched"
truth_df = pd.read_csv('~/datadrive/data/ground_truth_no_dups.csv')
truth_df['dir'] = truth_df['location_name'].map(
    lambda name: os.path.join(batched_root, name)
)

truth_df.head()

In [None]:
dim = 200
out_dir = 'inference_results'
os.makedirs(out_dir, exist_ok=True)


def _load_normalized_stack(paths, dim):
    """Read 4-band rasters into a (len(paths), 4, dim, dim) float array.

    Bands are read as (b, g, r, n) and reordered to (r, g, b, n). Each image is
    scaled by its own maximum into 0-255 and truncated to uint8 values before
    being stored in the float output array.
    """
    stack = np.empty((len(paths), 4, dim, dim))
    for k, path in enumerate(paths):
        with rasterio.open(path) as src:
            b, g, r, n = src.read()
        bands = np.stack((r, g, b, n), axis=0)
        peak = bands.max()
        # A tile with a non-positive (or NaN) maximum would divide by zero and cast
        # NaN/inf to uint8 garbage; store it as all zeros instead.
        if peak > 0:
            stack[k] = ((bands / peak) * 255).astype(np.uint8)
        else:
            stack[k] = 0
    return stack


def _predict_class1_probs(model, img_array, dim):
    """Run ``model`` on ``img_array`` and return sigmoid(channel-1 logits) as (N, dim, dim)."""
    # no_grad: pure inference, so don't build an autograd graph for the whole stack.
    with torch.no_grad():
        logits = model(torch.Tensor(img_array))
    logits = logits.reshape(logits.shape[0], 2, dim, dim)[:, 1, :, :]
    return torch.sigmoid(logits).numpy()


model.eval()  # inference behaviour for any dropout / batch-norm layers
for _, row in tqdm(truth_df.iterrows(), total=len(truth_df)):
    out_path = os.path.join(out_dir, f'{row.location_name}.p')
    # Skip label-2 rows and anything already computed, so the cell can resume.
    if row.label == 2 or os.path.exists(out_path):
        continue

    ims = sorted(glob.glob(os.path.join(row.dir, '*.tif')))
    if not ims:
        print(f'{row.location_name}: no .tif files in {row.dir}, skipping')
        continue

    probs = _predict_class1_probs(model, _load_normalized_stack(ims, dim), dim)

    # Write to a temp file, then rename: an interrupted run never leaves a truncated
    # pickle that the existence check above would treat as finished.
    tmp_path = out_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(probs, f)
    os.replace(tmp_path, out_path)


In [None]:
# Show which per-location pickles have been written so far.
[name for name in os.listdir(out_dir)]

In [None]:
def add_noise(arr, perc_noise=0.01, rng=None):
    """Return a copy of ``arr`` with a random fraction of entries flipped to ``1 - value``.

    Intended for probability maps: flipping p -> 1 - p reverses the probability
    at the chosen entries. The input array is not modified.

    Args:
        arr: array-like of any shape (values expected to be probabilities).
        perc_noise: fraction of entries to flip; the count is
            ``round(perc_noise * arr.size)``.
        rng: optional ``numpy.random.Generator`` for reproducible noise. When None,
            the global ``np.random`` state is used.

    Returns:
        numpy array with the same shape as ``arr``.

    Raises:
        ValueError: if the requested count is negative or exceeds ``arr.size``
            (raised by ``choice`` when sampling without replacement).
    """
    flat = np.asarray(arr).flatten()  # flatten() copies, so the caller's array is untouched
    n = flat.size
    # int(): round() of a numpy float may not return a Python int.
    n_pixels = int(round(perc_noise * n))
    chooser = np.random if rng is None else rng
    inds = chooser.choice(n, n_pixels, replace=False)

    flat[inds] = 1 - flat[inds]  # reverse probability
    return flat.reshape(np.shape(arr))


In [None]:
# Directory holding the per-location probability pickles written by the inference loop.
out_dir = "inference_results"

In [None]:
## Apply likelihood method to a single location for inspection.
# The plotting cells below read probs / pre / post produced here.
# Full-sweep variant (previously commented out in this cell): loop over the files in
# out_dir, optionally corrupt with add_noise(probs, 0.02), call
# lh.change_ts(smth, learning_rate=0.5, num_steps=5000, return_model=False, seed=<loc number>)
# and store the result in results[loc_name], checkpointing to disk every 20 locations.
target_loc = 'loc_0901.p'
results = {}

target_path = os.path.join(out_dir, target_loc)
if not os.path.exists(target_path):
    # Fail loudly rather than leave downstream cells plotting stale or undefined arrays.
    raise FileNotFoundError(f'{target_path} not found; run the inference cell first')

with open(target_path, 'rb') as f:
    # NOTE: pickle.load can execute arbitrary code -- only load files this notebook wrote.
    probs = pickle.load(f)

# Append a trailing singleton axis: (T, H, W) -> (T, H, W, 1).
# NOTE(review): presumably the layout lh.tf_smooth expects -- confirm.
probs = tf.expand_dims(tf.convert_to_tensor(probs, dtype=float), axis=-1)
smth = lh.tf_smooth(probs, 'box', size=3)
ts, null, pre, post, t, t_weight = lh.change_ts(smth, learning_rate=0.5, num_steps=4000, return_model=True)

        
    

In [None]:
import matplotlib.pyplot as plt

In [None]:
# exist_ok makes this cell safe to re-run (plain os.makedirs raises FileExistsError).
os.makedirs('mle_ims', exist_ok=True)

In [None]:
import matplotlib as mpl

# Default size (inches) for every figure created after this cell.
mpl.rcParams.update({'figure.figsize': (13, 8)})

In [None]:
# Save four panels as separate PNGs: the first probability frame, the second-to-last
# frame, and the pre / post arrays returned by lh.change_ts.
# NOTE(review): this uses probs[-2] while the 2x2 grid below shows probs[-1] -- confirm
# which frame is intended.
os.makedirs('mle_ims', exist_ok=True)  # don't depend on the separate makedirs cell
for i, im in enumerate([probs[0], probs[-2], pre, post]):
    fig, ax = plt.subplots()
    ax.imshow(im)
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    fig.savefig(f'mle_ims/3mle_{i}.png')
    plt.show()
    # Release the figure so open figures don't accumulate on non-inline backends.
    plt.close(fig)
    
    
    


In [None]:
import scipy.stats as st


In [None]:
# 2x2 overview: first and last probability frames, then the pre / post arrays
# mapped through the standard-normal quantile function (probit).
fig, axes = plt.subplots(2, 2, figsize=(15, 15))
panels = [probs[0], probs[-1], st.norm.ppf(pre), st.norm.ppf(post)]

for ax, panel in zip(axes.ravel(), panels):
    ax.imshow(panel)
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    






In [None]:
# Persist the per-location statistics collected by the likelihood cell.
# Guard: that cell resets `results = {}` and (in its single-location form) never fills
# it, so dumping here would overwrite any previously saved complete results with {}.
if not results:
    raise ValueError(
        "results is empty; refusing to overwrite 'results_noise_0.02_complete.p'"
    )

with open('results_noise_0.02_complete.p', 'wb') as f:
    pickle.dump(results, f)