In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import cm
from torch.utils.data import Dataset, DataLoader
from torch.tensor import Tensor
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

import warnings

from glob import glob

from PIL import Image


from hofer import UpperDiagonalThresholdedLogTransform, prepare_batch, SLayer

# Hofer Data Prep

# Prepping dim0 diagrams for vectorization

In [2]:
# Collect every dim0 barcode file; the order of this list is reused below when
# the batched rows are written back out, so it must not be re-globbed in between.
dim0_files = glob('./barcodes/dim0/*.npy')
# Transform applied to each diagram before batching.
# NOTE(review): .1 is presumably the threshold of the log transform - confirm in hofer.py.
trans = UpperDiagonalThresholdedLogTransform(.1)

In [3]:
# Load each dim0 barcode, wrap it in a torch tensor and apply `trans`;
# the resulting list is in the same order as `dim0_files`.
transformed_and_tensored = [trans(torch.tensor(np.load(d))) for d in dim0_files]

In [4]:
# prepare_batch returns four values; X is later indexed as X[ix, :, :], one
# leading index per input diagram.
# NOTE(review): presumably diagrams are padded to `max_points` rows and `dummy`
# marks the padding - confirm in hofer.py.
X, dummy, max_points, batch = prepare_batch(transformed_and_tensored)

# Saving dim0 diagrams

In [27]:
# Write row `ix` of the batched tensor to ./barcodes/dim0_vector/ under the
# same file name as its source diagram.
# Separators are normalised first: on Windows glob returns paths such as
# './barcodes/dim0\\name.npy', in which '/dim0/' never occurs, so the replace
# would be a no-op and np.save would overwrite the original barcode file.
# NOTE(review): assumes prepare_batch keeps the input order so that X[ix]
# belongs to dim0_files[ix] - confirm in hofer.py.
for ix, d in enumerate(dim0_files):
    out_path = d.replace('\\', '/').replace('/dim0/', '/dim0_vector/')
    np.save(out_path, X[ix, :, :])

In [44]:
# Persist the second value returned by prepare_batch for the dim0 diagrams
# in the same folder as the per-file arrays.
np.save('./barcodes/dim0_vector/dummy', dummy)

*** 

# Prepping dim1 diagrams for vectorization

In [45]:
# Collect every dim1 barcode file; the order of this list is reused below when
# the batched rows are written back out.
dim1_files = glob('./barcodes/dim1/*.npy')
# Same transform configuration as used for the dim0 diagrams (rebinds `trans`).
# NOTE(review): .1 is presumably the threshold of the log transform - confirm in hofer.py.
trans = UpperDiagonalThresholdedLogTransform(.1)

In [46]:
# Load each dim1 barcode, wrap it in a torch tensor and apply `trans`;
# the resulting list is in the same order as `dim1_files`.
transformed_and_tensored_dim1 = [trans(torch.tensor(np.load(d))) for d in dim1_files]

In [47]:
# prepare_batch returns four values; X_dim1 is later indexed as
# X_dim1[ix, :, :], one leading index per input diagram.
# NOTE(review): presumably diagrams are padded to a common number of rows and
# `dummy_dim1` marks the padding - confirm in hofer.py.
X_dim1, dummy_dim1, max_point_dim1, batch_dim1 = prepare_batch(transformed_and_tensored_dim1)

# Saving dim1 diagrams

In [31]:
# Write row `ix` of the batched tensor to ./barcodes/dim1_vector/ under the
# same file name as its source diagram.
# Separators are normalised first: on Windows glob returns paths such as
# './barcodes/dim1\\name.npy', in which '/dim1/' never occurs, so the replace
# would be a no-op and np.save would overwrite the original barcode file.
# NOTE(review): assumes prepare_batch keeps the input order so that X_dim1[ix]
# belongs to dim1_files[ix] - confirm in hofer.py.
for ix, d in enumerate(dim1_files):
    out_path = d.replace('\\', '/').replace('/dim1/', '/dim1_vector/')
    np.save(out_path, X_dim1[ix, :, :])

In [48]:
# Persist the second value returned by prepare_batch for the dim1 diagrams
# in the same folder as the per-file arrays.
np.save('./barcodes/dim1_vector/dummy', dummy_dim1)

In [34]:
# Spot-check one saved dim1 array by displaying its shape.
np.load('./barcodes/dim1_vector/1989_02_SIN_data_dim1.npy').shape

(1175, 2)

# Confirming Data Pipeline is working

In [14]:
# Sanity check: unsqueezing a length-2 tensor at dim=0 prepends a size-1 axis.
torch.unsqueeze(torch.tensor([1,2]), dim=0).shape

torch.Size([1, 2])

In [124]:
def clean_data(data):
    """Patch sentinel values in a 2-D frame, in place, and return the frame.

    Two rewrites are applied in order:

    1. every entry equal to -1 becomes 0;
    2. inside rows 143..166 and columns 223..246, every entry equal to 0
       (including those just produced by step 1) becomes 1.

    The argument itself is modified; the returned object is the same array.
    NOTE(review): the fixed window presumably covers a region with no valid
    measurements - confirm against the data description.
    """
    sentinel_mask = data == -1
    data[sentinel_mask] = 0

    # Basic slicing yields a view, so writing through it updates `data`.
    window = data[143:167, 223:247]
    window[window == 0] = 1
    return data

def resize_data(data):
    """Render a frame through the ``gist_earth`` colormap, shrink it and rescale.

    The frame is colour-mapped, converted to an 8-bit RGB image (the alpha
    channel is dropped), resized to 112 pixels wide by 76 pixels high,
    converted to grayscale and finally divided by its maximum.

    Parameters
    ----------
    data : array-like
        2-D frame accepted by ``cm.gist_earth``.

    Returns
    -------
    numpy.ndarray
        Float array of shape (76, 112).  If the grayscale image is entirely
        zero, an all-zero array is returned rather than dividing by zero
        (which would fill the result with NaN).
    """
    rgba = cm.gist_earth(data, alpha=None)
    rgb = np.uint8(rgba * 255)[:, :, :3]
    im = Image.fromarray(rgb)
    # PIL sizes are (width, height); the resulting array is (height, width).
    resized_data = np.array(im.resize((112, 76)).convert('L'))

    peak = resized_data.max()
    if peak == 0:
        # Nothing to normalise against; avoid 0/0 -> NaN.
        return np.zeros(resized_data.shape, dtype=float)
    return resized_data / peak

# class SeaIceDataset(Dataset):
    
#     def __init__(self, seq_len, data_folder='./data/*.pkl', return_dims=False, dim0_folder='dim0_vector', dim1_folder='dim1_vector'):
#         self.seq_len = seq_len
#         self.data_files = glob(data_folder)
#         self.return_dims = return_dims
#         self.dim0_folder = dim0_folder
#         self.dim1_folder = dim1_folder
        
#     def __len__(self):
#         return len(self.data_files) - (self.seq_len + 1)
    
#     def __getitem__(self, ix):
#         X = np.array([resize_data(clean_data(np.array(np.load(d)))) for d in self.data_files[ix:ix+self.seq_len]], dtype=np.float32)
#         y = resize_data(clean_data(np.load(self.data_files[ix+self.seq_len+1]))).flatten().astype(np.float32)
        
#         if self.return_dims:
#             dim0 = np.array([np.load('./barcodes/{}/{}'.format(self.dim0_folder, d.split('/')[-1].split('.')[0] + '_dim0.npy')) for d in self.data_files[ix:ix+self.seq_len]], dtype=np.float32)
#             dim1 = np.array([np.load('./barcodes/{}/{}'.format(self.dim1_folder, d.split('/')[-1].split('.')[0] + '_dim1.npy')) for d in self.data_files[ix:ix+self.seq_len]], dtype=np.float32)
#             return X, dim0, dim1, y
#         else:
#             return X, y


class SeaIceDataset(Dataset):
    """Sliding-window dataset over sea-ice frames stored one file per time step.

    Item ``ix`` is built from ``seq_len`` consecutive files: each frame is
    passed through ``clean_data`` and ``resize_data`` and the frames are
    stacked as the input ``X``.  The target ``y`` is one later frame,
    processed the same way and flattened.  With ``return_dims=True`` the
    transformed dim0 and dim1 barcodes of the input frames are returned too.

    Parameters
    ----------
    seq_len : int
        Number of consecutive frames in each input sequence.
    data_folder : str
        Glob pattern selecting the frame files.
    return_dims : bool
        When true, ``__getitem__`` returns ``(X, dim0, dim1, y)`` instead of
        ``(X, y)``; ``dim0`` and ``dim1`` are lists with one tensor per frame.

    NOTE(review): the default pattern selects ``.pkl`` files that are read
    with ``np.load``; loading pickled data executes arbitrary code and newer
    NumPy releases refuse it unless ``allow_pickle=True`` - confirm the file
    format and that the files are trusted.
    """

    def __init__(self, seq_len, data_folder='./data/*.pkl', return_dims=False):
        self.seq_len = seq_len
        # glob returns matches in arbitrary, platform-dependent order; sort so
        # that every window is built from a stable, reproducible sequence.
        # NOTE(review): assumes the file names sort chronologically (e.g. a
        # 'YYYY_MM_...' prefix) - confirm against the data directory.
        self.data_files = sorted(glob(data_folder))
        self.return_dims = return_dims

    def __len__(self):
        # Item ix reads files up to index ix + seq_len + 1, so the last
        # seq_len + 1 files cannot start a window.  Clamp at zero because
        # len() raises ValueError on a negative result.
        return max(0, len(self.data_files) - (self.seq_len + 1))

    @staticmethod
    def _stem(path):
        """Return the file name of ``path`` up to its first dot.

        Both forward slashes and backslashes are treated as separators, so
        the result is the same for paths produced by glob on any platform.
        """
        return path.replace('\\', '/').split('/')[-1].split('.')[0]

    def _load_barcodes(self, files, dim, trans):
        """Load and transform the ``dim`` ('dim0' or 'dim1') barcode of each frame file."""
        return [
            trans(torch.tensor(np.load('./barcodes/{0}/{1}_{0}.npy'.format(dim, self._stem(d)))))
            for d in files
        ]

    def __getitem__(self, ix):
        window = self.data_files[ix:ix+self.seq_len]
        X = np.array([resize_data(clean_data(np.array(np.load(d)))) for d in window], dtype=np.float32)
        # NOTE(review): the target is file ix + seq_len + 1, i.e. it skips the
        # frame immediately after the window.  __len__ is consistent with this
        # index, so it is left as is - confirm the two-step horizon is intended.
        y = resize_data(clean_data(np.load(self.data_files[ix+self.seq_len+1]))).flatten().astype(np.float32)

        if self.return_dims:
            trans = UpperDiagonalThresholdedLogTransform(.1)
            # Barcodes are looked up by the stem of each frame file.
            dim0 = self._load_barcodes(window, 'dim0', trans)
            dim1 = self._load_barcodes(window, 'dim1', trans)
            return X, dim0, dim1, y
        else:
            return X, y

In [125]:
# Small configuration for checking the pipeline end to end: images and targets
# only, batched with the DataLoader's default collation.
seq_length, batch_size = 3, 2
si_dataset = SeaIceDataset(seq_len=seq_length, return_dims=False)
train_loader = DataLoader(
    si_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

In [126]:
# Pull a single batch, print the shapes of its input and target, then stop.
# NOTE: this rebinds the notebook-level name `X`, which the dim0 preparation
# cell also assigns.
for X, y in train_loader:
    print(X.shape)
    print(y.shape)
    break

torch.Size([2, 3, 76, 112])
torch.Size([2, 8512])


In [None]:
(2, 3, X, 2)

In [107]:
YEET[0][1]

[tensor([[ 0.7071,  0.7071],
         [ 0.7071,  0.7071],
         [ 0.7071,  0.7071],
         ...,
         [10.6066, 10.6066],
         [14.3179, 14.3179],
         [70.7107, 70.7107]]), tensor([[ 0.7071,  0.7071],
         [ 0.7071,  0.7071],
         [ 0.7071,  0.7071],
         ...,
         [13.9463, 13.9463],
         [14.5602, 14.5602],
         [70.7107, 70.7107]]), tensor([[ 0.7071,  0.7071],
         [ 0.7071,  0.7071],
         [ 0.7071,  0.7071],
         ...,
         [ 7.2801,  7.2801],
         [10.5119, 10.5119],
         [70.7107, 70.7107]])]

In [110]:
[1,2,3,4,4]*

SyntaxError: invalid syntax (<ipython-input-110-aef61829cf62>, line 1)

In [132]:
def collation_station(batch):
    """Collate ``(X, dim0, dim1, y)`` samples into a single batch.

    Parameters
    ----------
    batch : sequence of tuples
        Each sample holds a NumPy array first and last (image sequence and
        target) and, at positions 1 and 2, lists of dim0 / dim1 diagrams.

    Returns
    -------
    tuple
        ``(X, dim0, dim1, y)`` where ``X`` and ``y`` are the samples' arrays
        stacked along a new leading axis, and ``dim0`` / ``dim1`` are whatever
        ``prepare_batch`` returns for the diagrams of all samples concatenated
        in sample order.
    """
    X = torch.stack([torch.from_numpy(sample[0]) for sample in batch], dim=0)
    y = torch.stack([torch.from_numpy(sample[-1]) for sample in batch], dim=0)

    # Flatten the per-sample diagram lists into one list per homology dimension.
    dim0_diagrams = [diagram for sample in batch for diagram in sample[1]]
    dim1_diagrams = [diagram for sample in batch for diagram in sample[2]]

    dim0 = prepare_batch(dim0_diagrams)
    dim1 = prepare_batch(dim1_diagrams)
    return X, dim0, dim1, y

In [133]:
# Same small configuration, now also returning the barcodes; the custom
# collate function is needed because the diagrams are lists of tensors.
seq_length, batch_size = 3, 2
si_dataset = SeaIceDataset(seq_len=seq_length, return_dims=True)
train_loader = DataLoader(
    si_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    collate_fn=collation_station,
)

In [141]:
dim0[3]

6

In [134]:
# Fetch one collated batch and stop, leaving X, dim0, dim1 and y bound at
# notebook level for interactive inspection.
for X, dim0, dim1, y in train_loader:
    break
    