In [10]:
import pandas as pd
import numpy as np
import einops
from matplotlib import pyplot as plt
import scipy.io
from utils import assert_equal
import random
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import os
import locusts


In [27]:
# Load every locust tracking recording into one dict keyed by recording id.
# Each value is the (arr, info) pair returned by import_data.
from locusts import import_data

# Match the full annotation suffix rather than just ".mat": the slice below
# removes exactly len(annotation_suffix) characters, so any other ".mat" file
# would otherwise be turned into a mangled recording id.
annotation_suffix = "_annotation.mat"

data = {}
# sorted(): os.listdir returns entries in arbitrary order; a fixed order keeps
# the dict (and anything later stacked from it) reproducible across runs.
for filename in sorted(os.listdir("Locusts/Data/Tracking")):
    if not filename.endswith(annotation_suffix):
        continue
    # e.g. "15UE20200509_annotation.mat" -> "15UE20200509"
    f = filename[:-len(annotation_suffix)]
    arr, info = import_data(f)
    data[f] = arr, info


In [72]:
def process_data(arr, info, n_features=4, max_N=30):
    """Pad one recording to a fixed row count and prepend the two food patches.

    The input ``arr`` has shape (T, N, D). The result has shape
    (T, max_N + 2, D + 3):

    * row 0 and row 1 hold food patches A and B,
    * rows 2 .. N + 1 hold the N agents copied from ``arr``,
    * any remaining rows are zero padding.

    Three feature columns are appended after the original D:

    * column D     -- is_food indicator (1 on the two food rows, 0 elsewhere)
    * column D + 1 -- whether the food patch is high quality
    * column D + 2 -- radius of the food patch

    Args:
        arr: array of shape (T, N, D).
        info: mapping providing 'posA', 'posB', 'isA_HQ', 'isB_HQ', 'radA'
            and 'radB'. Positions must broadcast to (T, D); the quality flags
            and radii must broadcast to (T,).
        n_features: unused; kept so existing callers keep working.
        max_N: number of agent rows to pad every recording to.

    Returns:
        A float64 ``numpy.ndarray`` of shape (T, max_N + 2, D + 3).

    Raises:
        ValueError: if the recording has more than ``max_N`` agents.
    """
    T, N, D = arr.shape
    if N > max_N:
        raise ValueError(f"recording has {N} agents but max_N is {max_N}")

    is_food_col, is_hq_col, radius_col = D, D + 1, D + 2
    data = np.zeros((T, max_N + 2, D + 3))

    # Agents start at row 2; rows 0 and 1 are reserved for the food patches.
    data[:, 2:N + 2, :D] = arr
    data[:, 0, :D] = info['posA']
    data[:, 1, :D] = info['posB']

    # Only the two food rows carry the is_food flag (previously every row did).
    data[:, 0:2, is_food_col] = 1

    # Patch A belongs in row 0; it used to be written to row 1 and then
    # overwritten by patch B, leaving row 0 without a quality flag.
    data[:, 0, is_hq_col] = info['isA_HQ']
    data[:, 1, is_hq_col] = info['isB_HQ']

    data[:, 0, radius_col] = info['radA']
    data[:, 1, radius_col] = info['radB']
    return data




In [73]:
# Run process_data over every recording, keeping the recording id as the key.
data2 = dict(
    (name, process_data(arr, info)) for name, (arr, info) in data.items()
)


In [75]:
# Build the full dataset: one tensor with a leading per-recording axis.
# torch.stack needs every recording to have the same shape.
data_list = torch.stack(tuple(map(torch.from_numpy, data2.values())))


In [83]:
# Cast the dataset to float32.
data_list = data_list.float()


In [84]:
data_list.shape  # (N_videos, T, N_locusts, D)

# The transformer takes in a timestep and predicts accelerations, so in
# transformer terms a batch is [B, N, D].
# Here: the first 100 timesteps of recording 0.
sample_batch = data_list[0, :100]



In [79]:
# Width of one token: the last axis of the sample batch.
D = sample_batch.shape[-1]

# batch_first=True because sample_batch is laid out [B, N, D]
# (timesteps, agents, features). The PyTorch default (batch_first=False) reads
# the input as (seq, batch, feature), which would attend across the timesteps
# of each agent instead of across the agents within one timestep.
# NOTE(review): nhead=D leaves each attention head a single feature dimension
# (d_model // nhead == 1) -- confirm this is intended.
encoder_layer = nn.TransformerEncoderLayer(d_model=D, nhead=D, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)



In [85]:
# Display the dtype of the batch that is passed to the encoder.
sample_batch.dtype


torch.float32

In [86]:
# Smoke test: run the encoder on the sample batch and display the output shape.
transformer_encoder(sample_batch).shape


torch.Size([100, 32, 5])