In [7]:
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms
import os


# Environment image size (width, height), taken from a reference training image.
# The context manager closes the file handle instead of leaking it.
with Image.open('raw/train/env_0.png') as _ref_img:
    dw, dh = _ref_img.size
# Per-step sizes of the x / y trajectory components: one scalar each, which is
# why process() appends a trailing size-1 axis to the trajectory tensor.
dx, dy = 1, 1

# Image preprocessing: PIL -> float tensor in [0, 1] -> first 3 channels -> [-1, 1].
env_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x[:3, :, :]),  # Keep only the first 3 channels (ignore alpha)
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # (x - 0.5) / 0.5 maps [0,1] to [-1,1]
])

# Load every env image / trajectory pair in a folder into a list of tensor dicts.
def process(in_folder='raw/train', transform=None, traj_scale=30.0):
    """Load paired environment images and trajectories from ``in_folder``.

    Every ``*.png`` in ``in_folder`` whose name does not contain ``'soln'``
    (e.g. ``env_3.png``) is treated as an environment image. Its trajectory file
    is found by replacing the first ``'env'`` in the file stem with ``'traj'`` and
    using the ``.npy`` extension (``traj_3.npy``).

    Args:
        in_folder: Directory containing the raw ``env_*.png`` / ``traj_*.npy`` files.
        transform: Callable applied to the RGB PIL image. Defaults to the
            notebook-level ``env_tf`` (ToTensor -> first 3 channels -> [-1, 1]).
        traj_scale: Divisor applied to the raw trajectory values. The default 30.0
            reproduces the original normalisation, which is intended to map
            coordinates into [0, 1]. This assumes raw values lie in [0, 30];
            TODO confirm.

    Returns:
        A list of dicts, in sorted-filename order, with these keys:
        ``'env'`` is the transformed image tensor. ``'x'`` and ``'y'`` are
        float32 tensors taken from ``traj[0]`` and ``traj[1]``, each carrying a
        trailing size-1 axis.

    Raises:
        FileNotFoundError: if an environment image has no matching trajectory file.
    """
    if transform is None:
        transform = env_tf  # resolved at call time so the notebook-level default is used

    data = []
    # os.listdir order is arbitrary and platform-dependent. Sorting gives the
    # saved dataset a reproducible sample order across runs and machines.
    for filename in sorted(os.listdir(in_folder)):
        if not filename.endswith('.png') or 'soln' in filename:
            continue

        # The context manager closes the file handle.
        # For RGBA inputs, convert('RGB') drops alpha, exactly like the x[:3] slice.
        # It also expands grayscale/palette images, which would otherwise break
        # the 3-channel Normalize.
        with Image.open(os.path.join(in_folder, filename)) as im:
            img = transform(im.convert('RGB'))

        stem = os.path.splitext(filename)[0]
        traj_file = stem.replace('env', 'traj', 1) + '.npy'
        traj = np.load(os.path.join(in_folder, traj_file)) / traj_scale
        # unsqueeze(-1) adds a trailing size-1 feature axis (dx = dy = 1).
        # If traj is stored as (2, T), x and y each end up (T, 1).
        # NOTE(review): confirm the on-disk trajectory shape.
        traj = torch.tensor(traj, dtype=torch.float32).unsqueeze(-1)

        data.append({'env': img, 'x': traj[0], 'y': traj[1]})
    return data
            

In [8]:
# Build both splits and persist them for training.
# torch.save does not create missing parent directories, so create the output dir first.
os.makedirs('processed', exist_ok=True)

train_data = process('raw/train')
val_data = process('raw/val')
# Fail loudly rather than silently saving an empty dataset
# (e.g. the folder contains no matching .png files).
assert train_data, "no samples found in raw/train"
assert val_data, "no samples found in raw/val"

# Save the processed data
torch.save(train_data, 'processed/train.pt')
torch.save(val_data, 'processed/val.pt')