In [1]:
import sys
import os

# Put the parent of the current working directory on the import path
# (presumably where the project-local `dataset` module lives -- confirm).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [2]:
import torch
import torch.nn as nn

from dataset import create_wall_dataloader

In [3]:
# Directory passed to the project's dataloader factory; kept in a named
# constant so the path is easy to spot and change.
TRAIN_DATA_DIR = "/scratch/DL25SP/train"
train_loader = create_wall_dataloader(TRAIN_DATA_DIR)

In [4]:
# Peek at the first batch only: split the states into the first frame and the
# remaining frames, separate channel 0 from channel 1, and print every shape.
for data in train_loader:
    # First frame, with the time axis kept as a singleton dimension.
    init_state = data.states[:, 0, :, :, :].unsqueeze(1)
    init_state_agent = init_state[:, :, 0, :, :].unsqueeze(2)  # channel 0
    init_state_env = init_state[:, :, 1, :, :].unsqueeze(2)    # channel 1

    # Every frame after the first one.
    later_state = data.states[:, 1:, :, :, :]
    later_state_agent = later_state[:, :, 0, :, :].unsqueeze(2)
    later_state_env = later_state[:, :, 1, :, :].unsqueeze(2)

    # Same seven lines of output as printing each shape individually.
    for tensor in (
        init_state,
        init_state_agent,
        init_state_env,
        later_state,
        later_state_agent,
        later_state_env,
        data.actions,
    ):
        print(tensor.shape)

    break

  states = torch.from_numpy(self.states[i]).float().to(self.device)


torch.Size([64, 1, 2, 65, 65])
torch.Size([64, 1, 1, 65, 65])
torch.Size([64, 1, 1, 65, 65])
torch.Size([64, 16, 2, 65, 65])
torch.Size([64, 16, 1, 65, 65])
torch.Size([64, 16, 1, 65, 65])
torch.Size([64, 16, 2])


In [5]:
class ConvNextBlock(nn.Module):
    """Residual block: depthwise 7x7 conv -> LayerNorm -> 1x1 expand -> GELU -> 1x1 project.

    The channel count and the spatial size of the input are preserved, so the
    block's output can be added back onto its input.

    Args:
        in_channels: number of channels of the input (and output) feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        expanded_channels = 4 * in_channels

        # groups=in_channels makes this a depthwise conv; padding=3 keeps H and W.
        self.dwConv = nn.Conv2d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels, bias=False)
        # Normalises over the channel dimension (applied channels-last in forward).
        self.ln = nn.LayerNorm(in_channels)
        self.pwConv4 = nn.Conv2d(in_channels, expanded_channels, kernel_size=1, bias=False)
        self.gelu = nn.GELU()
        self.pwConv = nn.Conv2d(expanded_channels, in_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return ``x`` plus the block's transformation of ``x`` (x is [b, C, H, W])."""
        h = self.dwConv(x)  # [b, C, H, W]
        # LayerNorm acts on the last dimension, so move channels last and back.
        h = self.ln(h.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        h = self.pwConv4(h)  # [b, 4C, H, W]
        h = self.pwConv(self.gelu(h))  # [b, C, H, W]
        return x + h

In [6]:
class StateEncoder(nn.Module):
    """Encode a two-channel frame into a single embedding vector.

    Channel 0 and channel 1 of the input are processed by two independent
    branches (named ``agent_*`` and ``env_*``). Each branch is a 1x1 conv stem,
    a stack of ``ConvNextBlock``s, dropout on the flattened features and a
    linear projection. The two branch embeddings are summed.

    Args:
        input_size: height and width of the (square) input frame.
        input_channel: channels fed to each branch's stem conv.
        hidden_dim: channel width inside each branch.
        embedding_dim: size of the returned embedding.
        layers: number of ``ConvNextBlock``s per branch.
    """

    def __init__(self, input_size=65, input_channel=1, hidden_dim=64, embedding_dim=256, layers=4):
        super().__init__()
        # Feature count after flattening a [hidden_dim, input_size, input_size] map.
        flat_features = hidden_dim * input_size * input_size

        self.agent_pwConv = nn.Conv2d(input_channel, hidden_dim, kernel_size=1, bias=False)
        self.agent_blocks = nn.Sequential(*(ConvNextBlock(hidden_dim) for _ in range(layers)))

        self.env_pwConv = nn.Conv2d(input_channel, hidden_dim, kernel_size=1, bias=False)
        self.env_blocks = nn.Sequential(*(ConvNextBlock(hidden_dim) for _ in range(layers)))

        self.agent_fc = nn.Linear(flat_features, embedding_dim)
        self.env_fc = nn.Linear(flat_features, embedding_dim)

        self.agent_dropout = nn.Dropout(p=0.1)
        self.env_dropout = nn.Dropout(p=0.1)

    @staticmethod
    def _embed(plane, stem, blocks, dropout, fc):
        """Run one branch: stem conv -> blocks -> flatten -> dropout -> linear."""
        features = blocks(stem(plane))  # [b, hidden_dim, H, W]
        features = features.flatten(start_dim=1)  # [b, hidden_dim * H * W]
        return fc(dropout(features))  # [b, embedding_dim]

    def forward(self, x):
        """Return the summed branch embeddings for ``x`` ([b, 2, H, W])."""
        # Take each channel and restore a singleton channel axis for the convs.
        agent_plane = x[:, 0, :, :].unsqueeze(1)  # [b, 1, H, W]
        env_plane = x[:, 1, :, :].unsqueeze(1)  # [b, 1, H, W]

        # Agent branch runs before the env branch.
        agent_rep = self._embed(agent_plane, self.agent_pwConv, self.agent_blocks, self.agent_dropout, self.agent_fc)
        env_rep = self._embed(env_plane, self.env_pwConv, self.env_blocks, self.env_dropout, self.env_fc)

        return agent_rep + env_rep  # [b, embedding_dim]

In [22]:
class StatePredictor(nn.Module):
    """Residual MLP mapping (state embedding, action, velocity) to a new state embedding.

    Each input row is the concatenation ``[state | action | velocity]``. The
    network predicts a correction that is added to the state slice of the
    input, so the output has ``state_dim`` features.

    Args:
        state_dim: size of the state embedding (leading features of the input
            and size of the output).
        action_dim: number of action features following the state.
        velocity_dim: number of velocity features following the action.
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, state_dim=256, action_dim=2, velocity_dim=1, hidden_dim=1024):
        super().__init__()
        # Remembered so forward() can slice the state part of the input for the
        # residual connection instead of assuming the default width of 256.
        self.state_dim = state_dim

        self.linear1 = nn.Linear(in_features=state_dim+action_dim+velocity_dim, out_features=hidden_dim)
        self.relu1 = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(p=0.1)

        self.linear2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.relu2 = nn.ReLU()
        self.bn2 = nn.BatchNorm1d(hidden_dim)

        self.linear3 = nn.Linear(in_features=hidden_dim, out_features=state_dim)

    def forward(self, x):
        """Return the predicted state embedding, shape [b, state_dim].

        Args:
            x: [b, state_dim + action_dim + velocity_dim] ([b, 259] with defaults).
        """
        # State slice of the input, reused as the residual branch: [b, state_dim].
        original = x[:, :self.state_dim]

        s = self.linear1(x)
        s = self.relu1(s)
        s = self.bn1(s)
        s = self.dropout(s)

        s = self.linear2(s)
        s = self.relu2(s)
        s = self.bn2(s)

        s = self.linear3(s)  # [b, state_dim]

        return original + s


In [23]:
class ExploreJEPA(nn.Module):
    """Roll a state embedding forward through a sequence of actions.

    The first frame is encoded once; the predictor is then applied
    ``trajectory_length`` times, each step consuming the previous embedding,
    the step's action and a scalar "velocity" (the L2 norm of the difference
    between the two most recent embeddings, zero at the first step).

    Args:
        trajectory_length: number of prediction steps; ``forward`` indexes
            ``a[:, i]`` for every ``i`` below this value.
    """

    def __init__(self, trajectory_length):
        super().__init__()
        self.trajectory_length = trajectory_length
        self.init_state_encoder = StateEncoder()
        # Registered as a sub-module but not called by forward() below.
        self.later_state_encoder = StateEncoder()
        self.state_predictor = StatePredictor()

    def forward(self, x, a):
        """Return a list of ``trajectory_length`` predicted embeddings.

        Args:
            x: frames, [b, T, 2, 65, 65]; only the frame at index 0 is read.
            a: actions, [b, >= trajectory_length, 2].

        Returns:
            List of tensors, one per step (each [b, 256] with the default
            encoder/predictor sizes).
        """
        init_state = x[:, 0, :, :, :]  # [b, 2, 65, 65]
        init_state_rep = self.init_state_encoder(init_state)  # [b, 256]

        predicted_state_rep = []
        prev_state_rep = None  # no earlier embedding exists at step 0
        cur_state_rep = init_state_rep
        for i in range(self.trajectory_length):
            cur_action = a[:, i, :]  # [b, 2]
            if prev_state_rep is None:
                # Allocate directly on the action's device instead of creating
                # the tensor on the CPU and copying it over on every call.
                cur_velocity = torch.zeros(cur_action.shape[0], 1, device=cur_action.device)
            else:
                cur_velocity = torch.norm(cur_state_rep - prev_state_rep, dim=-1, keepdim=True)
            cur_input = torch.cat([cur_state_rep, cur_action, cur_velocity], dim=1)  # [b, 259]

            next_state_rep = self.state_predictor(cur_input)
            predicted_state_rep.append(next_state_rep)
            # Slide the two-embedding window forward for the next velocity.
            prev_state_rep, cur_state_rep = cur_state_rep, next_state_rep
        return predicted_state_rep

In [24]:
# 16 rollout steps; Module.to() returns the module itself, so chaining keeps
# `model` bound to the same object. The bare name displays its repr.
model = ExploreJEPA(trajectory_length=16).to("cuda")
model

ExploreJEPA(
  (init_state_encoder): StateEncoder(
    (agent_pwConv): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (agent_blocks): Sequential(
      (0): ConvNextBlock(
        (dwConv): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64, bias=False)
        (ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (pwConv4): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (gelu): GELU(approximate='none')
        (pwConv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): ConvNextBlock(
        (dwConv): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64, bias=False)
        (ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (pwConv4): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (gelu): GELU(approximate='none')
        (pwConv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (2): C

In [25]:
# Smoke test: push the first batch through the model once, then stop.
for data in train_loader:
    # First frame only, with the time axis kept as a singleton dimension.
    init_state = data.states[:, 0, :, :, :].unsqueeze(1)
    init_state_gpu = init_state.to("cuda")
    actions_gpu = data.actions.to("cuda")
    predicted_s = model(init_state_gpu, actions_gpu)

    break