In [28]:
# Change the working directory to the project root.
# Presumably this is so the relative paths used later in the notebook
# (e.g. 'checkpoints') resolve from there -- confirm.
import os
project_path = '/workspace/'
os.chdir(project_path)


# Load packages
import torch
import hydra
from omegaconf import DictConfig
from torch.utils.data import DataLoader
import random
import numpy as np
from src.models.evflownet import EVFlowNet
from src.datasets import DatasetProvider
from enum import Enum, auto
from src.datasets import train_collate
from tqdm import tqdm
from pathlib import Path
from typing import Dict, Any
import os
import time

In [29]:
!pip install hydra-core --upgrade
!pip install hdf5plugin
!sudo apt-get update


Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:2 http://security.ubuntu.com/ubuntu jammy-security InRelease               
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease                         
Hit:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:5 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Reading package lists... Done


In [30]:
# %env HYDRA_FULL_ERROR=1
# !python /workspace/main.py

In [31]:
# Read the Hydra YAML config and load it into `args`.

import os
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf

# Absolute directory that holds the Hydra config files.
CONFIG_DIR = "/workspace/configs"

with initialize_config_dir(config_dir=CONFIG_DIR, version_base=None):
    # Compose the config named "base" found in CONFIG_DIR.
    args = compose(config_name="base")

# Show the resolved configuration as YAML for inspection.
print(OmegaConf.to_yaml(args))

dataset_path: data
seed: 42
num_epoch: 10
data_loader:
  common:
    num_voxel_bins: 15
  train:
    batch_size: 16
    shuffle: false
  test:
    batch_size: 1
    shuffle: false
train:
  no_batch_norm: false
  initial_learning_rate: 0.01
  weight_decay: 0.0001
  epochs: 10



In [32]:
# Definitions of the helper enum and functions
class RepresentationType(Enum):
    '''Identifiers for the event representations that can be requested.'''
    # Explicit values; these are the same numbers auto() assigns (1, 2, ...).
    VOXEL = 1
    STEPAN = 2

def set_seed(seed):
    '''
    Seed the Python, NumPy and PyTorch random number generators with `seed`
    and switch cuDNN to deterministic, non-benchmarking mode.
    seed: int => value passed to every generator
    '''
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags: no autotuned kernel selection, deterministic algorithms only.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def compute_epe_error(pred_flow: torch.Tensor, gt_flow: torch.Tensor):
    '''
    Compute the mean end-point error (EPE) between predicted and ground-truth flow.
    The per-pixel error is the Euclidean (L2) distance between the two flow
    vectors, i.e. the norm over the channel dimension (dim=1); it is not squared.
    That error is averaged over the spatial dimensions and then over the batch.
    pred_flow: torch.Tensor, Shape: torch.Size([B, 2, 480, 640]) => predicted optical flow
    gt_flow: torch.Tensor, Shape: torch.Size([B, 2, 480, 640]) => ground-truth optical flow
    returns: 0-dim torch.Tensor => mean EPE over the batch
    '''
    # torch.norm is deprecated; torch.linalg.vector_norm is its documented replacement.
    per_pixel_epe = torch.linalg.vector_norm(pred_flow - gt_flow, ord=2, dim=1)  # [B, H, W]
    per_sample_epe = per_pixel_epe.mean(dim=(1, 2))  # [B]
    return per_sample_epe.mean(dim=0)

def save_optical_flow_to_npy(flow: torch.Tensor, file_name: str):
    '''
    Save an optical flow tensor to "<file_name>.npy".
    flow: torch.Tensor, Shape: torch.Size([2, 480, 640]) => optical flow data
    file_name: str => output file name; ".npy" is appended to it
    '''
    # detach() first: Tensor.numpy() raises for a tensor that requires grad.
    np.save(f"{file_name}.npy", flow.detach().cpu().numpy())


In [33]:
# Create the directory for saved models and load the data.

# Create the directory if it doesn't exist.
# exist_ok=True makes the call idempotent and removes the check-then-create race.
os.makedirs('checkpoints', exist_ok=True)

set_seed(args.seed)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Directory layout:
#
#     data
#     ├─test
#     |  ├─test_city
#     |  |    ├─events_left
#     |  |    |   ├─events.h5
#     |  |    |   └─rectify_map.h5
#     |  |    └─forward_timestamps.txt
#     └─train
#         ├─zurich_city_11_a
#         |    ├─events_left
#         |    |       ├─ events.h5
#         |    |       └─ rectify_map.h5
#         |    ├─ flow_forward
#         |    |       ├─ 000134.png
#         |    |       |.....
#         |    └─ forward_timestamps.txt
#         ├─zurich_city_11_b
#         └─zurich_city_11_c

# ------------------
#    Dataloader
# ------------------
loader = DatasetProvider(
    representation_type=RepresentationType.VOXEL,
    dataset_path=Path(args.dataset_path),
    # NOTE(review): num_bins is hard-coded here while the config also defines
    # data_loader.common.num_voxel_bins -- confirm which value is intended.
    num_bins=4,
    delta_t_ms=100,
)
train_set = loader.get_train_dataset()
test_set = loader.get_test_dataset()

# The same collate function is used for both loaders.
collate_fn = train_collate

train_data = DataLoader(
    train_set,
    collate_fn=collate_fn,
    drop_last=False,
    batch_size=args.data_loader.train.batch_size,
    shuffle=args.data_loader.train.shuffle,
)
test_data = DataLoader(
    test_set,
    collate_fn=collate_fn,
    drop_last=False,
    batch_size=args.data_loader.test.batch_size,
    shuffle=args.data_loader.test.shuffle,
)

# Written as comments rather than a bare string literal: a string that is the
# last expression of a cell gets echoed back as the cell's output.
#
# train data:
#     Type of batch: Dict
#     Key: seq_name, Type: list
#     Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => batch of event data
#     Key: flow_gt, Type: torch.Tensor, Shape: torch.Size([Batch, 2, 480, 640]) => batch of optical flow data
#     Key: flow_gt_valid_mask, Type: torch.Tensor, Shape: torch.Size([Batch, 1, 480, 640]) => validity mask of the optical flow data. Not used in the baseline
#
# test data:
#     Type of batch: Dict
#     Key: seq_name, Type: list
#     Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => batch of event data

'\ntrain data:\n    Type of batch: Dict\n    Key: seq_name, Type: list\n    Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => イベントデータのバッチ\n    Key: flow_gt, Type: torch.Tensor, Shape: torch.Size([Batch, 2, 480, 640]) => オプティカルフローデータのバッチ\n    Key: flow_gt_valid_mask, Type: torch.Tensor, Shape: torch.Size([Batch, 1, 480, 640]) => オプティカルフローデータのvalid. ベースラインでは使わない\n\ntest data:\n    Type of batch: Dict\n    Key: seq_name, Type: list\n    Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => イベントデータのバッチ\n'

In [34]:
# Shape of the event volume of the first training sample.
train_set[0]["event_volume"].shape

torch.Size([4, 480, 640])

In [35]:
def save_model(model):
    '''
    Save the model's state_dict to "checkpoints/model_<YYYYmmddHHMMSS>.pth".
    model: object exposing state_dict() (e.g. a torch.nn.Module)
    returns: str => path of the file that was written
    '''
    # torch.save does not create missing directories, so make sure it exists.
    os.makedirs("checkpoints", exist_ok=True)
    current_time = time.strftime("%Y%m%d%H%M%S")
    model_path = f"checkpoints/model_{current_time}.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")
    return model_path


def train_model(args, model, n_epoch=1):
    '''
    Train `model` for `n_epoch` epochs, then predict flow for the test set and save it.
    Uses the notebook-level globals `train_data`, `test_data` and `device`.
    args: config object => args.train.initial_learning_rate and args.train.weight_decay are read
    model: the network to train; called as model(event_volume) and expected to return the flow batch
    n_epoch: int => number of passes over train_data. With 0 no checkpoint is
        written and the prediction uses the weights the model already holds.
    Side effects: writes "checkpoints/model_<timestamp>.pth" after every epoch
    and "submission.npy" at the end.
    '''
    # ------------------
    #   optimizer
    # ------------------
    optimizer = torch.optim.Adam(model.parameters(), lr=args.train.initial_learning_rate, weight_decay=args.train.weight_decay)

    # ------------------
    #   Start training
    # ------------------
    # Stays None when no epoch runs, so the reload below can be skipped
    # instead of failing on an unbound name.
    model_path = None
    model.train()
    for epoch in range(n_epoch):
        total_loss = 0
        print("on epoch: {}".format(epoch+1))
        for i, batch in enumerate(tqdm(train_data)):
            event_image = batch["event_volume"].to(device) # [B, 4, 480, 640]
            ground_truth_flow = batch["flow_gt"].to(device) # [B, 2, 480, 640]
            flow = model(event_image) # [B, 2, 480, 640]
            loss: torch.Tensor = compute_epe_error(flow, ground_truth_flow)
            print(f"batch {i} loss: {loss.item()}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        # max(..., 1) avoids a ZeroDivisionError when the loader yields no batch.
        print(f'Epoch {epoch+1}, Loss: {total_loss / max(len(train_data), 1)}')

        # Checkpoint after every epoch; torch.save needs the directory to exist.
        os.makedirs("checkpoints", exist_ok=True)
        current_time = time.strftime("%Y%m%d%H%M%S")
        model_path = f"checkpoints/model_{current_time}.pth"
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")

    # ------------------
    #   Start predicting
    # ------------------
    if model_path is not None:
        model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    # Collect per-batch outputs and concatenate once, instead of re-copying the
    # growing result tensor on every iteration.
    batch_flows = []
    with torch.no_grad():
        print("start test")
        for batch in tqdm(test_data):
            event_image = batch["event_volume"].to(device)
            batch_flows.append(model(event_image)) # [1, 2, 480, 640]
        print("test done")
    if batch_flows:
        flow = torch.cat(batch_flows, dim=0)  # [N, 2, 480, 640]
    else:
        # Same empty result the loop-free case produced before.
        flow = torch.tensor([]).to(device)
    # ------------------
    #  save submission
    # ------------------
    file_name = "submission"
    save_optical_flow_to_npy(flow, file_name)

In [36]:
# save_model(model)

In [37]:
# Build a fresh model, load the weights stored at model_path_load into it,
# then pass it to train_model.
model = EVFlowNet(args.train).to(device)
model_path_load = "/workspace/checkpoints/model_20240716200019.pth"
model.load_state_dict(
    torch.load(model_path_load, map_location=device)
)
train_model(args, model)

on epoch: 1


  return F.conv2d(input, weight, bias, self.stride,


batch 0 loss: 4.399824956145882


  1%|          | 1/126 [00:40<1:24:34, 40.60s/it]

batch 1 loss: 7.152964834839732


  2%|▏         | 2/126 [01:10<1:10:49, 34.27s/it]

batch 2 loss: 6.7186343874860945


  2%|▏         | 3/126 [01:38<1:04:13, 31.33s/it]

batch 3 loss: 5.639368628834988


  3%|▎         | 4/126 [02:06<1:01:02, 30.02s/it]

batch 4 loss: 5.907382438774267


  4%|▍         | 5/126 [02:35<1:00:18, 29.90s/it]

batch 5 loss: 5.632936280214679


  5%|▍         | 6/126 [03:06<59:55, 29.96s/it]  

batch 6 loss: 9.556576313100594


  6%|▌         | 7/126 [03:34<58:17, 29.39s/it]

batch 7 loss: 14.262369935276439


  6%|▋         | 8/126 [04:02<57:21, 29.16s/it]

batch 8 loss: 7.132203950092187


  7%|▋         | 9/126 [04:32<57:23, 29.43s/it]

batch 9 loss: 13.841371341413563


  8%|▊         | 10/126 [05:04<58:02, 30.02s/it]

batch 10 loss: 12.056673430148681


  9%|▊         | 11/126 [05:31<56:08, 29.29s/it]

batch 11 loss: 14.199502496065556


 10%|▉         | 12/126 [05:59<54:36, 28.74s/it]

batch 12 loss: 9.296042598989054


 10%|█         | 13/126 [06:28<54:17, 28.82s/it]

batch 13 loss: 5.437187317394529


 11%|█         | 14/126 [06:58<54:17, 29.08s/it]

batch 14 loss: 3.6845877713392228


 12%|█▏        | 15/126 [07:25<53:02, 28.67s/it]

batch 15 loss: 9.767099542688179


 13%|█▎        | 16/126 [07:53<52:01, 28.37s/it]

batch 16 loss: 10.907782628587126


 13%|█▎        | 17/126 [08:22<51:44, 28.48s/it]

batch 17 loss: 7.8279083700211665


 14%|█▍        | 18/126 [08:51<51:31, 28.62s/it]

batch 18 loss: 11.141424262812036


 15%|█▌        | 19/126 [09:19<50:59, 28.60s/it]

batch 19 loss: 11.869404082245651


 16%|█▌        | 20/126 [09:50<51:29, 29.15s/it]

batch 20 loss: 6.774786926332842


 17%|█▋        | 21/126 [10:19<51:11, 29.25s/it]

batch 21 loss: 9.890864818681386


 17%|█▋        | 22/126 [10:51<51:56, 29.97s/it]

batch 22 loss: 7.893690818305061


 18%|█▊        | 23/126 [11:21<51:33, 30.03s/it]

batch 23 loss: 4.78522070251724


 19%|█▉        | 24/126 [11:49<50:08, 29.49s/it]

batch 24 loss: 5.668228328678697


 20%|█▉        | 25/126 [12:19<49:37, 29.48s/it]

batch 25 loss: 5.410503576105226


 21%|██        | 26/126 [12:49<49:28, 29.68s/it]

batch 26 loss: 4.737418406596651


 21%|██▏       | 27/126 [13:18<48:44, 29.54s/it]

batch 27 loss: 3.9099961800583634


 22%|██▏       | 28/126 [13:47<47:44, 29.23s/it]

batch 28 loss: 3.664172650078094


 23%|██▎       | 29/126 [14:16<47:22, 29.31s/it]

batch 29 loss: 4.457859108754231


 24%|██▍       | 30/126 [14:46<47:00, 29.38s/it]

batch 30 loss: 4.649702436644075


 25%|██▍       | 31/126 [15:14<45:52, 28.97s/it]

batch 31 loss: 4.053158111261299


 25%|██▌       | 32/126 [15:42<45:01, 28.74s/it]

batch 32 loss: 3.5939999057426033


 26%|██▌       | 33/126 [16:11<44:51, 28.94s/it]

batch 33 loss: 3.7274204905816664


 27%|██▋       | 34/126 [16:41<44:56, 29.31s/it]

batch 34 loss: 3.8238040236242794


 28%|██▊       | 35/126 [17:10<44:03, 29.05s/it]

batch 35 loss: 2.8594776962740434


 29%|██▊       | 36/126 [17:38<43:11, 28.80s/it]

batch 36 loss: 2.6682329128669475


 29%|██▉       | 37/126 [18:08<43:16, 29.18s/it]

batch 37 loss: 3.1499149015363597


 30%|███       | 38/126 [18:37<42:52, 29.24s/it]

batch 38 loss: 3.4125405328180163


 31%|███       | 39/126 [19:05<41:45, 28.79s/it]

batch 39 loss: 2.918908979639734


 32%|███▏      | 40/126 [19:33<40:55, 28.55s/it]

batch 40 loss: 2.8435596206685547


 33%|███▎      | 41/126 [20:02<40:43, 28.75s/it]

batch 41 loss: 3.005121570963735


 33%|███▎      | 42/126 [20:32<40:42, 29.08s/it]

batch 42 loss: 3.6618017550571507


 34%|███▍      | 43/126 [21:00<39:31, 28.58s/it]

batch 43 loss: 3.798682028979024


 35%|███▍      | 44/126 [21:27<38:37, 28.26s/it]

batch 44 loss: 3.052358308023142


 36%|███▌      | 45/126 [21:56<38:31, 28.54s/it]

batch 45 loss: 3.1890637205051586


 37%|███▋      | 46/126 [22:26<38:31, 28.90s/it]

batch 46 loss: 2.670833631355305


 37%|███▋      | 47/126 [22:55<38:01, 28.88s/it]

batch 47 loss: 2.994720715163725


 38%|███▊      | 48/126 [23:23<37:02, 28.49s/it]

batch 48 loss: 3.21611835392471


 39%|███▉      | 49/126 [23:52<36:47, 28.66s/it]

batch 49 loss: 3.0769197775128463


 40%|███▉      | 50/126 [24:21<36:39, 28.94s/it]

batch 50 loss: 3.476553049833272


 40%|████      | 51/126 [24:49<35:44, 28.60s/it]

batch 51 loss: 3.590483978811304


 41%|████▏     | 52/126 [25:17<34:56, 28.33s/it]

batch 52 loss: 3.5041425511948803


 42%|████▏     | 53/126 [25:47<35:01, 28.78s/it]

batch 53 loss: 3.4439664263503325


 43%|████▎     | 54/126 [26:16<34:53, 29.07s/it]

batch 54 loss: 3.1357340458992566


 44%|████▎     | 55/126 [26:44<34:02, 28.77s/it]

batch 55 loss: 3.2707880605545183


 44%|████▍     | 56/126 [27:12<33:10, 28.44s/it]

batch 56 loss: 2.607447777684838


 45%|████▌     | 57/126 [27:41<32:50, 28.55s/it]

batch 57 loss: 2.7453651204981835


 46%|████▌     | 58/126 [28:10<32:37, 28.79s/it]

batch 58 loss: 2.983012029585729


 47%|████▋     | 59/126 [28:38<31:51, 28.54s/it]

batch 59 loss: 2.7520179467918657
