In [64]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, random_split
from utils import read_data
from models import NextFrameModel
import glob

# Create a Torch Dataset

To feed the generated intensity data into deep-learning (DL) models,
we first convert it into PyTorch tensors and wrap those tensors in a
PyTorch `Dataset`.



In [35]:
# Collect the paths of all generated intensity files. Their order does not
# matter here; the slices are sorted by z in a later cell.
_files = list(glob.iglob("code/Intensity*"))

In [18]:
# Read every intensity file into a (z, intensity) tuple.
z_slices = [(z, intensity) for z, intensity in map(read_data, _files)]

In [19]:
def _z_position(z_and_intensity):
    """Sort key: the z coordinate of a (z, intensity) pair."""
    return z_and_intensity[0]


# Order the slices by z so that adjacent list entries are adjacent in z.
z_slices = sorted(z_slices, key=_z_position)

In [20]:
# Build (input, target) pairs from z-sorted slices: slice 0 -> slice 1,
# slice 2 -> slice 3, and so on. Pairs do not overlap, so each slice appears
# in at most one pair.
# zip() over the even- and odd-indexed sub-lists stops at the shorter one, so a
# trailing unpaired slice (odd number of files) is dropped instead of raising
# IndexError on z_slices[i + 1].
# NOTE(review): if every consecutive pair (i, i + 1) was intended, use
# zip(z_slices, z_slices[1:]) instead -- confirm which pairing is wanted.
data_frame = []
predict_frame = []
data_frame_z = []
predict_frame_z = []
for (input_z, input_intensity), (target_z, target_intensity) in zip(
    z_slices[0::2], z_slices[1::2]
):
    data_frame.append(input_intensity)
    predict_frame.append(target_intensity)
    data_frame_z.append(input_z)
    predict_frame_z.append(target_z)

### Converting to Torch TensorDataset

Once the data is read, we convert each native Python list to a NumPy array and then convert that array to a `float32` Torch tensor.

In [25]:
def _as_float_tensor(values):
    """Convert `values` to a NumPy array and return it as a float32 tensor."""
    return torch.from_numpy(np.array(values)).float()


# Convert the frame lists and their z coordinates in one pass; the tuple of
# originals is built before any name is rebound.
data_frame, predict_frame, data_frame_z, predict_frame_z = (
    _as_float_tensor(values)
    for values in (data_frame, predict_frame, data_frame_z, predict_frame_z)
)

In [26]:
# Each dataset item is (input frame, target frame, input z, target z), in that
# order; the evaluation and training loops unpack it positionally.
_dataset_tensors = [data_frame, predict_frame, data_frame_z, predict_frame_z]
pytorch_dataset = TensorDataset(*_dataset_tensors)

In [37]:
# Hold out the remaining 30% of the pairs for testing. A seeded generator makes
# the split identical on every Restart & Run All; without it random_split uses
# the global RNG and train/test membership changes from run to run.
TRAIN_FRACTION = 0.7
SPLIT_SEED = 42

training_set_size = int(TRAIN_FRACTION * len(pytorch_dataset))
test_set_size = len(pytorch_dataset) - training_set_size
training_set, test_set = random_split(
    pytorch_dataset,
    [training_set_size, test_set_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

## Initialize the Model, Data Loaders, Loss, and Optimizers

In [65]:
# Run on the GPU when one is available, otherwise fall back to the CPU instead
# of failing on a hard-coded .cuda().
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): Adam's per-step parameter change is on the order of the
# learning rate, so 1e-8 barely moves the weights; the flat training loss in
# the log below is consistent with that. Consider a larger value (Adam's
# default is 1e-3) and/or normalizing the intensities, whose MSE is ~1e28.
LEARNING_RATE = 1e-8
NUM_WORKERS = 10

training_batch_size = 32
# Evaluate the whole test split as a single batch. DataLoader rejects
# batch_size=0, so clamp to 1 so that an empty test split does not crash here.
test_batch_size = max(1, len(test_set))
model = NextFrameModel().to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

training_loader = DataLoader(training_set,
                             batch_size=training_batch_size,
                             shuffle=True,
                             num_workers=NUM_WORKERS)
test_loader = DataLoader(test_set,
                         batch_size=test_batch_size,
                         shuffle=False,
                         num_workers=NUM_WORKERS)

### Baseline Accuracy

We record the test error (MSE) of the untrained model to use as a baseline.

In [66]:
# Mean squared error of the untrained model over the whole test split.
model.eval()  # switch dropout/batch-norm style layers (if any) to inference behaviour
device = next(model.parameters()).device  # put batches on the model's device
total_squared_error = 0.0
num_samples = 0
with torch.no_grad():  # turn off gradient tracking; nothing is trained here
    for cur_frame, next_frame, _cur_z, _next_z in test_loader:
        cur_frame = cur_frame.to(device)
        next_frame = next_frame.to(device)

        model_prediction = model(cur_frame.unsqueeze(dim=1))
        # nn.MSELoss() (reduction='mean') averages within a batch, so weight by
        # batch size to get the mean over the split however it is batched.
        # Argument order follows the (prediction, target) convention.
        n_in_batch = cur_frame.size(0)
        total_squared_error += criterion(
            model_prediction, next_frame.unsqueeze(dim=1)
        ).item() * n_in_batch
        num_samples += n_in_batch
# Guard against an empty test split (no batches at all).
test_mse = total_squared_error / max(num_samples, 1)
print(f"Baseline test error is {test_mse}")

Baseline test error is 2.2385545286595975e+29


### Train the Model

In [67]:
NUM_EPOCHS = 30

# Put every batch on whatever device the model lives on.
device = next(model.parameters()).device
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_squared_error = 0.0
    epoch_samples = 0
    for cur_frame, next_frame, _cur_z, _next_z in training_loader:
        cur_frame = cur_frame.to(device)
        next_frame = next_frame.to(device)
        model_prediction = model(cur_frame.unsqueeze(dim=1))

        # (prediction, target) argument order, as nn.MSELoss documents it.
        loss = criterion(model_prediction, next_frame.unsqueeze(dim=1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size so that a smaller final
        # batch does not skew the epoch average.
        n_in_batch = cur_frame.size(0)
        epoch_squared_error += loss.item() * n_in_batch
        epoch_samples += n_in_batch
    # Running mean over the epoch's samples; guarded against an empty loader.
    epoch_loss = epoch_squared_error / max(epoch_samples, 1)
    print(f"Training error at epoch {epoch} : {epoch_loss}")

Training error at epoch 0 : 3.1332487029415953e+28
Training error at epoch 1 : 3.1419479513284325e+28
Training error at epoch 2 : 3.1262960611515016e+28
Training error at epoch 3 : 3.1354361621260415e+28
Training error at epoch 4 : 3.116567966520263e+28
Training error at epoch 5 : 3.1333478332846816e+28
Training error at epoch 6 : 3.1526857862963436e+28
Training error at epoch 7 : 3.096939076380202e+28
Training error at epoch 8 : 3.1323469867382185e+28
Training error at epoch 9 : 3.1333478332846816e+28
Training error at epoch 10 : 3.2030134232611764e+28
Training error at epoch 11 : 3.130972227148947e+28
Training error at epoch 12 : 3.156054230632048e+28
Training error at epoch 13 : 3.1334088698714727e+28
Training error at epoch 14 : 3.1219451874986174e+28
Training error at epoch 15 : 3.130972128766312e+28
Training error at epoch 16 : 3.125615036875091e+28
Training error at epoch 17 : 3.130070550298624e+28
Training error at epoch 18 : 3.1306715501396773e+28
Training error at epoch 19 : 