# パッケージ・パス

In [1]:
# パスへの移動
import os
project_path = '/workspace/'
os.chdir(project_path)


# パッケージのロード
import torch
import hydra
from omegaconf import DictConfig
from torch.utils.data import DataLoader
import random
import numpy as np
from src.models.evflownet import EVFlowNet
from src.datasets import DatasetProvider
from enum import Enum, auto
from src.datasets import train_collate
from tqdm import tqdm
from pathlib import Path
from typing import Dict, Any
import os
import time

In [2]:
!pip install hydra-core --upgrade
!pip install hdf5plugin
!sudo apt-get update


Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:2 http://security.ubuntu.com/ubuntu jammy-security InRelease         
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease                   
Hit:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:5 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Reading package lists... Done


In [3]:
# %env HYDRA_FULL_ERROR=1
# !python /workspace/main.py

In [4]:
# hydra用のyamlファイルを読んでargsにロードする。

import os
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf

with initialize_config_dir(version_base=None, config_dir="/workspace/configs"):
    args = compose(config_name="base")

print(OmegaConf.to_yaml(args))

dataset_path: data
seed: 42
num_epoch: 10
data_loader:
  common:
    num_voxel_bins: 15
  train:
    batch_size: 8
    shuffle: true
  test:
    batch_size: 1
    shuffle: false
train:
  no_batch_norm: false
  initial_learning_rate: 0.01
  weight_decay: 0.0001
  epochs: 10



In [5]:
# 関数の定義
class RepresentationType(Enum):
    VOXEL = auto()
    STEPAN = auto()

def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)

def compute_epe_error(pred_flow: torch.Tensor, gt_flow: torch.Tensor):
    '''
    end-point-error (ground truthと予測値の二乗誤差)を計算
    pred_flow: torch.Tensor, Shape: torch.Size([B, 2, 480, 640]) => 予測したオプティカルフローデータ
    gt_flow: torch.Tensor, Shape: torch.Size([B, 2, 480, 640]) => 正解のオプティカルフローデータ
    '''
    epe = torch.mean(torch.mean(torch.norm(pred_flow - gt_flow, p=2, dim=1), dim=(1, 2)), dim=0)
    return epe

def save_optical_flow_to_npy(flow: torch.Tensor, file_name: str):
    '''
    optical flowをnpyファイルに保存
    flow: torch.Tensor, Shape: torch.Size([2, 480, 640]) => オプティカルフローデータ
    file_name: str => ファイル名
    '''
    np.save(f"{file_name}.npy", flow.cpu().numpy())


# データローダ

In [15]:
# モデルの保存ディレクトリの作成・データのロード。

# Create the directory if it doesn't exist
if not os.path.exists('checkpoints'):
    os.makedirs('checkpoints')


set_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
'''
    ディレクトリ構造:

    data
    ├─test
    |  ├─test_city
    |  |    ├─events_left
    |  |    |   ├─events.h5
    |  |    |   └─rectify_map.h5
    |  |    └─forward_timestamps.txt
    └─train
        ├─zurich_city_11_a
        |    ├─events_left
        |    |       ├─ events.h5
        |    |       └─ rectify_map.h5
        |    ├─ flow_forward
        |    |       ├─ 000134.png
        |    |       |.....
        |    └─ forward_timestamps.txt
        ├─zurich_city_11_b
        └─zurich_city_11_c
    '''


# ------------------
#    Dataloader
# ------------------
loader = DatasetProvider(
    dataset_path=Path(args.dataset_path),
    representation_type=RepresentationType.VOXEL,
    delta_t_ms=100,
    num_bins=4
)
train_set = loader.get_train_dataset()
test_set = loader.get_test_dataset()
collate_fn = train_collate


train_data = DataLoader(train_set,
                                batch_size=args.data_loader.train.batch_size,
                                shuffle=args.data_loader.train.shuffle,
                                collate_fn=collate_fn,
                                drop_last=False)
test_data = DataLoader(test_set,
                                batch_size=args.data_loader.test.batch_size,
                                shuffle=args.data_loader.test.shuffle,
                                collate_fn=collate_fn,
                                drop_last=False)

'''
train data:
    Type of batch: Dict
    Key: seq_name, Type: list
    Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => イベントデータのバッチ
    Key: flow_gt, Type: torch.Tensor, Shape: torch.Size([Batch, 2, 480, 640]) => オプティカルフローデータのバッチ
    Key: flow_gt_valid_mask, Type: torch.Tensor, Shape: torch.Size([Batch, 1, 480, 640]) => オプティカルフローデータのvalid. ベースラインでは使わない

test data:
    Type of batch: Dict
    Key: seq_name, Type: list
    Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => イベントデータのバッチ
'''

'\ntrain data:\n    Type of batch: Dict\n    Key: seq_name, Type: list\n    Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => イベントデータのバッチ\n    Key: flow_gt, Type: torch.Tensor, Shape: torch.Size([Batch, 2, 480, 640]) => オプティカルフローデータのバッチ\n    Key: flow_gt_valid_mask, Type: torch.Tensor, Shape: torch.Size([Batch, 1, 480, 640]) => オプティカルフローデータのvalid. ベースラインでは使わない\n\ntest data:\n    Type of batch: Dict\n    Key: seq_name, Type: list\n    Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => イベントデータのバッチ\n'

In [7]:
def save_model(model):
    current_time = time.strftime("%Y%m%d%H%M%S")
    model_path = f"checkpoints/model_{current_time}.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # ------------------
    #   Start predicting ()
    # ------------------
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    flow: torch.Tensor = torch.tensor([]).to(device)
    with torch.no_grad():
        print("start test")
        for batch in tqdm(test_data):
            batch: Dict[str, Any]
            event_image = batch["event_volume"].to(device)
            batch_flow = model(event_image) # [1, 2, 480, 640]
            flow = torch.cat((flow, batch_flow), dim=0)  # [N, 2, 480, 640]
        print("test done")
    # ------------------
    #  save submission
    # ------------------
    file_name = "submission"
    save_optical_flow_to_npy(flow, file_name)
    print("Submission saved")


    # ------------------
    #   Start predicting ()
    # ------------------
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    flow: torch.Tensor = torch.tensor([]).to(device)
    with torch.no_grad():
        print("start test")
        for batch in tqdm(test_data):
            batch: Dict[str, Any]
            event_image = batch["event_volume"].to(device)
            batch_flow = model(event_image) # [1, 2, 480, 640]
            flow = torch.cat((flow, batch_flow), dim=0)  # [N, 2, 480, 640]
        print("test done")
    # ------------------
    #  save submission
    # ------------------
    file_name = "submission"
    save_optical_flow_to_npy(flow, file_name)

# トレーニング関数

In [8]:
# 学習率スケジューラー
from torch.optim import lr_scheduler

def train_model(args, model, n_epoch=1):
    # ------------------
    #   optimizer
    # ------------------
    optimizer = torch.optim.Adam(model.parameters(), lr=args.train.initial_learning_rate, weight_decay=args.train.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)

    # ------------------
    #   Start training
    # ------------------
    model.train()
    for epoch in range(n_epoch):
        total_loss = 0
        print("on epoch: {}".format(epoch+1))
        for i, batch in enumerate(tqdm(train_data)):
            batch: Dict[str, Any]
            event_image = batch["event_volume"].to(device) # [B, 4, 480, 640]
            ground_truth_flow = batch["flow_gt"].to(device) # [B, 2, 480, 640]
            flow = model(event_image) # [B, 2, 480, 640]
            loss: torch.Tensor = compute_epe_error(flow, ground_truth_flow)
            print(f"batch {i} loss: {loss.item()}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step(loss.item())
            print(f"learining rate: {scheduler.get_last_lr()}")

            # new_lr = scheduler(i)
            # set_lr(new_lr, optimizer)

            if loss < 2.5:
                torch.save(model.state_dict(), "checkpoints/model_under2_5.pth")
                print("tmp model saved!")

            total_loss += loss.item()
        average_loss = total_loss / len(train_data)
        print(f'Epoch {epoch+1}, Loss: {average_loss}')
        
        current_time = time.strftime("%Y%m%d%H%M%S")
        model_path = f"checkpoints/model_{current_time}.pth"
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")


    # ------------------
    #   Start predicting ()
    # ------------------
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    flow: torch.Tensor = torch.tensor([]).to(device)
    with torch.no_grad():
        print("start test")
        for batch in tqdm(test_data):
            batch: Dict[str, Any]
            event_image = batch["event_volume"].to(device)
            batch_flow = model(event_image) # [1, 2, 480, 640]
            flow = torch.cat((flow, batch_flow), dim=0)  # [N, 2, 480, 640]
        print("test done")
    # ------------------
    #  save submission
    # ------------------
    file_name = "submission"
    save_optical_flow_to_npy(flow, file_name)

In [9]:
def evaluate(model, file_name="submission"):
    # ------------------
    #   Start predicting ()
    # ------------------
    model.load_state_dict(torch.load("checkpoints/model_tmp.pth", map_location=device))
    model.eval()
    flow: torch.Tensor = torch.tensor([]).to(device)
    with torch.no_grad():
        print("start test")
        for batch in tqdm(test_data):
            batch: Dict[str, Any]
            event_image = batch["event_volume"].to(device)
            batch_flow = model(event_image) # [1, 2, 480, 640]
            flow = torch.cat((flow, batch_flow), dim=0)  # [N, 2, 480, 640]
        print("test done")
    # ------------------
    #  save submission
    # ------------------
    save_optical_flow_to_npy(flow, file_name)

# 2d ver.

In [10]:
# model_path_loadのモデルをロード
model_path_load="/workspace/checkpoints/model_20240717025440.pth"
model = EVFlowNet(args.train).to(device)
model.load_state_dict(torch.load(model_path_load, map_location=device))
train_model(args, model)

evaluate(model)

on epoch: 1


  return F.conv2d(input, weight, bias, self.stride,


batch 0 loss: 2.882945358133317


  0%|          | 1/252 [00:07<31:34,  7.55s/it]

learining rate: [0.01]
batch 1 loss: 7.964097044271981


  1%|          | 2/252 [00:14<30:51,  7.41s/it]

learining rate: [0.01]
batch 2 loss: 5.299360674265261


  1%|          | 3/252 [00:22<30:35,  7.37s/it]

learining rate: [0.01]
batch 3 loss: 18.10494331275967


  2%|▏         | 4/252 [00:28<29:25,  7.12s/it]

learining rate: [0.008]
batch 4 loss: 7.683758577316789


  2%|▏         | 5/252 [00:36<29:49,  7.25s/it]

learining rate: [0.01]
batch 5 loss: 5.260834195917714


  2%|▏         | 6/252 [00:42<28:47,  7.02s/it]

learining rate: [0.01]
batch 6 loss: 4.7421369699083895


  3%|▎         | 7/252 [00:50<29:12,  7.15s/it]

learining rate: [0.01]
batch 7 loss: 4.086549839775708


  3%|▎         | 8/252 [00:57<28:36,  7.04s/it]

learining rate: [0.01]
batch 8 loss: 4.373871901906492


  4%|▎         | 9/252 [01:04<28:45,  7.10s/it]

learining rate: [0.01]
batch 9 loss: 4.448087942899019


  4%|▍         | 10/252 [01:11<28:18,  7.02s/it]

learining rate: [0.01]
batch 10 loss: 4.840691375202018


  4%|▍         | 11/252 [01:18<28:59,  7.22s/it]

learining rate: [0.01]
batch 11 loss: 3.6233074956258218


  5%|▍         | 12/252 [01:25<28:26,  7.11s/it]

learining rate: [0.01]
batch 12 loss: 3.97656004346325


  5%|▌         | 13/252 [01:33<28:30,  7.16s/it]

learining rate: [0.01]
batch 13 loss: 3.815441402317383


  6%|▌         | 14/252 [01:39<27:55,  7.04s/it]

learining rate: [0.01]
batch 14 loss: 5.008849801620603


  6%|▌         | 15/252 [01:46<27:54,  7.07s/it]

learining rate: [0.01]
batch 15 loss: 5.0285678544476085


  6%|▋         | 16/252 [01:53<27:20,  6.95s/it]

learining rate: [0.01]
batch 16 loss: 4.287112455461329


  7%|▋         | 17/252 [02:00<27:31,  7.03s/it]

learining rate: [0.01]
batch 17 loss: 5.7525008582231845


  7%|▋         | 18/252 [02:07<27:20,  7.01s/it]

learining rate: [0.01]
batch 18 loss: 3.441233174903262


  8%|▊         | 19/252 [02:15<27:28,  7.07s/it]

learining rate: [0.01]
batch 19 loss: 5.941102144350193


  8%|▊         | 20/252 [02:22<27:14,  7.05s/it]

learining rate: [0.01]
batch 20 loss: 5.214999859514849


  8%|▊         | 21/252 [02:29<27:21,  7.11s/it]

learining rate: [0.01]
batch 21 loss: 4.324956709439748


  9%|▊         | 22/252 [02:36<26:53,  7.02s/it]

learining rate: [0.01]
batch 22 loss: 5.291048148924431


  9%|▉         | 23/252 [02:43<27:00,  7.08s/it]

learining rate: [0.01]
batch 23 loss: 3.7651368571835664


 10%|▉         | 24/252 [02:49<26:27,  6.96s/it]

learining rate: [0.01]
batch 24 loss: 4.02173058964218


 10%|▉         | 25/252 [02:57<26:44,  7.07s/it]

learining rate: [0.01]
batch 25 loss: 3.751424651364447


 10%|█         | 26/252 [03:04<26:19,  6.99s/it]

learining rate: [0.01]
batch 26 loss: 3.535908383829792


 11%|█         | 27/252 [03:11<26:20,  7.02s/it]

learining rate: [0.01]
batch 27 loss: 3.05007380000103


 11%|█         | 28/252 [03:18<26:04,  6.98s/it]

learining rate: [0.01]
batch 28 loss: 3.194563767934338


 12%|█▏        | 29/252 [03:25<26:30,  7.13s/it]

learining rate: [0.01]
batch 29 loss: 3.4836632376888983


 12%|█▏        | 30/252 [03:32<25:55,  7.01s/it]

learining rate: [0.01]
batch 30 loss: 4.092746173262425


 12%|█▏        | 31/252 [03:39<26:15,  7.13s/it]

learining rate: [0.01]
batch 31 loss: 3.9485774616657414


 13%|█▎        | 32/252 [03:46<25:49,  7.04s/it]

learining rate: [0.01]
batch 32 loss: 4.602485172487317


 13%|█▎        | 33/252 [03:53<26:07,  7.16s/it]

learining rate: [0.01]
batch 33 loss: 4.399599901881779


 13%|█▎        | 34/252 [04:01<25:53,  7.13s/it]

learining rate: [0.01]
batch 34 loss: 2.7344665812853637


 14%|█▍        | 35/252 [04:08<25:53,  7.16s/it]

learining rate: [0.01]
batch 35 loss: 4.867302948284415


 14%|█▍        | 36/252 [04:15<25:45,  7.16s/it]

learining rate: [0.01]
batch 36 loss: 3.968220241412511


 15%|█▍        | 37/252 [04:22<25:58,  7.25s/it]

learining rate: [0.01]
batch 37 loss: 3.337376104296551


 15%|█▌        | 38/252 [04:29<25:36,  7.18s/it]

learining rate: [0.01]
batch 38 loss: 2.9499502803979327


 15%|█▌        | 39/252 [04:37<25:32,  7.19s/it]

learining rate: [0.01]
batch 39 loss: 5.094045655188147


 16%|█▌        | 40/252 [04:43<24:53,  7.05s/it]

learining rate: [0.01]
batch 40 loss: 5.597382717867227


 16%|█▋        | 41/252 [04:51<25:11,  7.16s/it]

learining rate: [0.01]
batch 41 loss: 4.969121593829987


 17%|█▋        | 42/252 [04:58<24:45,  7.07s/it]

learining rate: [0.01]
batch 42 loss: 8.669046646938838


 17%|█▋        | 43/252 [05:05<24:53,  7.15s/it]

learining rate: [0.01]
batch 43 loss: 10.950356627211642


 17%|█▋        | 44/252 [05:12<24:34,  7.09s/it]

learining rate: [0.008]
batch 44 loss: 5.047271486005005


 18%|█▊        | 45/252 [05:19<24:40,  7.15s/it]

learining rate: [0.01]
batch 45 loss: 5.260150967648322


 18%|█▊        | 46/252 [05:26<24:15,  7.07s/it]

learining rate: [0.01]
batch 46 loss: 6.425229026386621


 19%|█▊        | 47/252 [05:34<24:32,  7.18s/it]

learining rate: [0.01]
batch 47 loss: 7.743779626559682


 19%|█▉        | 48/252 [05:40<24:09,  7.10s/it]

learining rate: [0.01]
batch 48 loss: 8.07796806518954


 19%|█▉        | 49/252 [05:48<24:26,  7.22s/it]

learining rate: [0.01]
batch 49 loss: 5.978456284844864


 20%|█▉        | 50/252 [05:55<24:15,  7.20s/it]

learining rate: [0.01]
batch 50 loss: 5.937716369099247


 20%|██        | 51/252 [06:02<24:11,  7.22s/it]

learining rate: [0.01]
batch 51 loss: 7.390540130913894


 21%|██        | 52/252 [06:09<23:45,  7.13s/it]

learining rate: [0.01]
batch 52 loss: 4.479421294220755


 21%|██        | 53/252 [06:17<23:55,  7.21s/it]

learining rate: [0.01]
batch 53 loss: 4.745053238498862


 21%|██▏       | 54/252 [06:24<23:38,  7.16s/it]

learining rate: [0.01]
batch 54 loss: 4.91549917844251


 22%|██▏       | 55/252 [06:31<23:51,  7.27s/it]

learining rate: [0.01]
batch 55 loss: 3.3771392605664596


 22%|██▏       | 56/252 [06:38<23:15,  7.12s/it]

learining rate: [0.01]
batch 56 loss: 3.1500273137742285


 23%|██▎       | 57/252 [06:45<23:16,  7.16s/it]

learining rate: [0.01]
batch 57 loss: 4.073938788890281


 23%|██▎       | 58/252 [06:52<22:55,  7.09s/it]

learining rate: [0.01]
batch 58 loss: 4.934104400144292


 23%|██▎       | 59/252 [07:00<23:24,  7.28s/it]

learining rate: [0.01]
batch 59 loss: 3.9766089938090596


 24%|██▍       | 60/252 [07:07<23:00,  7.19s/it]

learining rate: [0.01]
batch 60 loss: 3.8342416098019743


 24%|██▍       | 61/252 [07:14<22:59,  7.22s/it]

learining rate: [0.01]
batch 61 loss: 3.9332272627098313


 25%|██▍       | 62/252 [07:21<22:35,  7.14s/it]

learining rate: [0.01]
batch 62 loss: 3.494400602743087


 25%|██▌       | 63/252 [07:28<22:32,  7.16s/it]

learining rate: [0.01]
batch 63 loss: 3.0731098612850234


 25%|██▌       | 64/252 [07:35<22:05,  7.05s/it]

learining rate: [0.01]
batch 64 loss: 3.527004496097154


 26%|██▌       | 65/252 [07:43<22:24,  7.19s/it]

learining rate: [0.01]
batch 65 loss: 4.266165506536865


 26%|██▌       | 66/252 [07:50<22:14,  7.17s/it]

learining rate: [0.01]
batch 66 loss: 3.6530586515278807


 27%|██▋       | 67/252 [07:57<22:09,  7.19s/it]

learining rate: [0.01]
batch 67 loss: 3.87804536443313


 27%|██▋       | 68/252 [08:04<21:58,  7.16s/it]

learining rate: [0.01]
batch 68 loss: 2.962055991475848


 27%|██▋       | 69/252 [08:11<21:56,  7.19s/it]

learining rate: [0.01]
batch 69 loss: 3.710680242689651


 28%|██▊       | 70/252 [08:18<21:21,  7.04s/it]

learining rate: [0.01]
batch 70 loss: 3.7192604096921063


 28%|██▊       | 71/252 [08:25<21:22,  7.09s/it]

learining rate: [0.01]
batch 71 loss: 6.718011441368278


 29%|██▊       | 72/252 [08:32<21:14,  7.08s/it]

learining rate: [0.01]
batch 72 loss: 4.495234175958906


 29%|██▉       | 73/252 [08:41<22:22,  7.50s/it]

learining rate: [0.01]
batch 73 loss: 4.067910314290202


 29%|██▉       | 74/252 [08:49<22:32,  7.60s/it]

learining rate: [0.01]
batch 74 loss: 4.6269098067706755


 30%|██▉       | 75/252 [08:57<23:00,  7.80s/it]

learining rate: [0.01]
batch 75 loss: 3.6088794854328192


 30%|███       | 76/252 [09:05<22:45,  7.76s/it]

learining rate: [0.01]
batch 76 loss: 3.2106321681389014


 31%|███       | 77/252 [09:14<23:56,  8.21s/it]

learining rate: [0.01]
batch 77 loss: 3.196798476080245


 31%|███       | 78/252 [09:22<23:50,  8.22s/it]

learining rate: [0.01]
batch 78 loss: 3.7543044536014554


 31%|███▏      | 79/252 [09:31<23:54,  8.29s/it]

learining rate: [0.01]
batch 79 loss: 3.396085302742821


 32%|███▏      | 80/252 [09:38<22:59,  8.02s/it]

learining rate: [0.01]
batch 80 loss: 3.1481996026339907


 32%|███▏      | 81/252 [09:46<22:52,  8.02s/it]

learining rate: [0.01]
batch 81 loss: 3.308911441148612


 33%|███▎      | 82/252 [09:54<22:24,  7.91s/it]

learining rate: [0.01]
batch 82 loss: 3.540777340672311


 33%|███▎      | 83/252 [10:02<22:38,  8.04s/it]

learining rate: [0.01]
batch 83 loss: 2.87831349869462


 33%|███▎      | 84/252 [10:09<21:50,  7.80s/it]

learining rate: [0.01]
batch 84 loss: 2.663063923258413


 34%|███▎      | 85/252 [10:18<22:18,  8.01s/it]

learining rate: [0.01]
batch 85 loss: 3.2195839699185864


 34%|███▍      | 86/252 [10:26<22:28,  8.12s/it]

learining rate: [0.01]
batch 86 loss: 3.4985140490441147


 35%|███▍      | 87/252 [10:34<22:29,  8.18s/it]

learining rate: [0.01]
batch 87 loss: 2.9128209585324187


 35%|███▍      | 88/252 [10:42<21:49,  7.98s/it]

learining rate: [0.01]
batch 88 loss: 3.1589237604917457


 35%|███▌      | 89/252 [10:50<21:37,  7.96s/it]

learining rate: [0.01]
batch 89 loss: 2.6419090076627962


 36%|███▌      | 90/252 [10:57<21:02,  7.79s/it]

learining rate: [0.01]
batch 90 loss: 2.997593692410375


 36%|███▌      | 91/252 [11:06<21:23,  7.97s/it]

learining rate: [0.01]
batch 91 loss: 2.972639736761937


 37%|███▋      | 92/252 [11:14<21:13,  7.96s/it]

learining rate: [0.01]
batch 92 loss: 2.743314632913511


 37%|███▋      | 93/252 [11:22<21:21,  8.06s/it]

learining rate: [0.01]
batch 93 loss: 3.4517148381122817


 37%|███▋      | 94/252 [11:30<20:59,  7.97s/it]

learining rate: [0.01]
batch 94 loss: 2.686612651161396


 38%|███▊      | 95/252 [11:38<20:54,  7.99s/it]

learining rate: [0.01]
batch 95 loss: 3.019150467968499


 38%|███▊      | 96/252 [11:45<20:22,  7.83s/it]

learining rate: [0.01]
batch 96 loss: 2.3688422025110456


 38%|███▊      | 97/252 [11:53<20:33,  7.96s/it]

learining rate: [0.01]
tmp model saved!
batch 97 loss: 2.8719154911197293


 39%|███▉      | 98/252 [12:01<19:56,  7.77s/it]

learining rate: [0.01]
batch 98 loss: 2.497013067622177


 39%|███▉      | 99/252 [12:09<20:07,  7.89s/it]

learining rate: [0.01]
tmp model saved!
batch 99 loss: 2.485855153764672


 40%|███▉      | 100/252 [12:16<19:46,  7.80s/it]

learining rate: [0.01]
tmp model saved!
batch 100 loss: 3.125245763012146


 40%|████      | 101/252 [12:25<19:50,  7.88s/it]

learining rate: [0.01]
batch 101 loss: 3.806483623868172


 40%|████      | 102/252 [12:32<19:30,  7.81s/it]

learining rate: [0.01]
batch 102 loss: 3.0026772172335043


 41%|████      | 103/252 [12:40<19:36,  7.90s/it]

learining rate: [0.01]
batch 103 loss: 2.863824714157296


 41%|████▏     | 104/252 [12:48<19:17,  7.82s/it]

learining rate: [0.01]
batch 104 loss: 3.2435815017218057


 42%|████▏     | 105/252 [12:56<19:16,  7.86s/it]

learining rate: [0.01]
batch 105 loss: 2.827616429327078


 42%|████▏     | 106/252 [13:03<18:50,  7.74s/it]

learining rate: [0.01]
batch 106 loss: 3.0333303765286637


 42%|████▏     | 107/252 [13:11<18:59,  7.86s/it]

learining rate: [0.01]
batch 107 loss: 3.0220516082544817


 43%|████▎     | 108/252 [13:19<18:36,  7.75s/it]

learining rate: [0.01]
batch 108 loss: 2.832284098578761


 43%|████▎     | 109/252 [13:27<18:43,  7.86s/it]

learining rate: [0.01]
batch 109 loss: 3.0104732742235565


 44%|████▎     | 110/252 [13:34<18:13,  7.70s/it]

learining rate: [0.01]
batch 110 loss: 3.4243480705394864


 44%|████▍     | 111/252 [13:43<18:26,  7.85s/it]

learining rate: [0.01]
batch 111 loss: 3.3144665396877215


 44%|████▍     | 112/252 [13:50<18:04,  7.75s/it]

learining rate: [0.01]
batch 112 loss: 2.690398103862759


 45%|████▍     | 113/252 [13:58<18:09,  7.84s/it]

learining rate: [0.01]
batch 113 loss: 2.456297482322154


 45%|████▌     | 114/252 [14:06<17:57,  7.81s/it]

learining rate: [0.01]
tmp model saved!
batch 114 loss: 3.6267661794991666


 46%|████▌     | 115/252 [14:14<18:08,  7.94s/it]

learining rate: [0.01]
batch 115 loss: 2.532274716996543


 46%|████▌     | 116/252 [14:22<17:38,  7.78s/it]

learining rate: [0.01]
batch 116 loss: 3.5066904433395227


 46%|████▋     | 117/252 [14:30<17:49,  7.92s/it]

learining rate: [0.01]
batch 117 loss: 4.0125882728811


 46%|████▋     | 117/252 [14:37<16:53,  7.50s/it]


KeyboardInterrupt: 

In [12]:
# 学習率スケジューラー
from torch.optim import lr_scheduler

def train_model(args, model, n_epoch=1):
    # ------------------
    #   optimizer
    # ------------------
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=args.train.weight_decay)

    # ------------------
    #   Start training
    # ------------------
    model.train()
    for epoch in range(n_epoch):
        total_loss = 0
        print("on epoch: {}".format(epoch+1))
        for i, batch in enumerate(tqdm(train_data)):
            batch: Dict[str, Any]
            event_image = batch["event_volume"].to(device) # [B, 4, 480, 640]
            ground_truth_flow = batch["flow_gt"].to(device) # [B, 2, 480, 640]
            flow = model(event_image) # [B, 2, 480, 640]
            loss: torch.Tensor = compute_epe_error(flow, ground_truth_flow)
            print(f"batch {i} loss: {loss.item()}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            # new_lr = scheduler(i)
            # set_lr(new_lr, optimizer)

            if loss < 2.3:
                current_time = time.strftime("%Y%m%d%H%M%S")
                torch.save(model.state_dict(), f"checkpoints/model_tmp_{current_time}.pth")
                print("tmp model saved!")

            total_loss += loss.item()
        average_loss = total_loss / len(train_data)
        print(f'Epoch {epoch+1}, Loss: {average_loss}')
        
        current_time = time.strftime("%Y%m%d%H%M%S")
        model_path = f"checkpoints/model_{current_time}.pth"
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")


    # ------------------
    #   Start predicting ()
    # ------------------
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    flow: torch.Tensor = torch.tensor([]).to(device)
    with torch.no_grad():
        print("start test")
        for batch in tqdm(test_data):
            batch: Dict[str, Any]
            event_image = batch["event_volume"].to(device)
            batch_flow = model(event_image) # [1, 2, 480, 640]
            flow = torch.cat((flow, batch_flow), dim=0)  # [N, 2, 480, 640]
        print("test done")
    # ------------------
    #  save submission
    # ------------------
    file_name = "submission"
    save_optical_flow_to_npy(flow, file_name)

In [14]:
evaluate(model)

start test


  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 97/97 [00:09<00:00, 10.75it/s]


test done


In [13]:
# model_path_loadのモデルをロード
model_path_load="/workspace/checkpoints/model_under2_5.pth"
model = EVFlowNet(args.train).to(device)
model.load_state_dict(torch.load(model_path_load, map_location=device))
train_model(args, model)

evaluate(model)

on epoch: 1


  0%|          | 0/252 [00:00<?, ?it/s]

batch 0 loss: 3.841246244727871


  0%|          | 1/252 [00:08<34:07,  8.16s/it]

batch 1 loss: 3.2783570480625484


  1%|          | 2/252 [00:16<34:46,  8.35s/it]

batch 2 loss: 2.756024718486039


  1%|          | 3/252 [00:24<33:06,  7.98s/it]

batch 3 loss: 3.042670696728262


  2%|▏         | 4/252 [00:31<31:56,  7.73s/it]

batch 4 loss: 3.1549487278275627


  2%|▏         | 5/252 [00:39<31:29,  7.65s/it]

batch 5 loss: 3.102080711684323


  2%|▏         | 6/252 [00:46<31:07,  7.59s/it]

batch 6 loss: 3.4198335031421827


  3%|▎         | 7/252 [00:54<31:02,  7.60s/it]

batch 7 loss: 2.492196407364289


  3%|▎         | 8/252 [01:01<30:46,  7.57s/it]

batch 8 loss: 2.5631867079165738


  4%|▎         | 9/252 [01:09<30:37,  7.56s/it]

batch 9 loss: 3.0970622424020373


  4%|▍         | 10/252 [01:16<30:33,  7.58s/it]

batch 10 loss: 3.78546129270989


  4%|▍         | 11/252 [01:24<31:10,  7.76s/it]

batch 11 loss: 2.4223092064893144


  5%|▍         | 12/252 [01:33<31:35,  7.90s/it]

batch 12 loss: 2.6848495692542222


  5%|▌         | 13/252 [01:40<30:59,  7.78s/it]

batch 13 loss: 2.565180346791974


  6%|▌         | 14/252 [01:48<30:38,  7.72s/it]

batch 14 loss: 2.822029143679829


  6%|▌         | 15/252 [01:55<30:17,  7.67s/it]

batch 15 loss: 2.8342761820523905


  6%|▋         | 16/252 [02:03<30:03,  7.64s/it]

batch 16 loss: 2.9553209782934706


  7%|▋         | 17/252 [02:11<30:21,  7.75s/it]

batch 17 loss: 2.8302427128149823


  7%|▋         | 18/252 [02:19<30:09,  7.73s/it]

batch 18 loss: 2.5221174348862347


  8%|▊         | 19/252 [02:26<30:11,  7.78s/it]

batch 19 loss: 2.4060743477707485


  8%|▊         | 20/252 [02:34<29:52,  7.73s/it]

batch 20 loss: 2.764015071058996


  8%|▊         | 21/252 [02:42<29:25,  7.64s/it]

batch 21 loss: 2.5532110334102085


  9%|▊         | 22/252 [02:49<29:14,  7.63s/it]

batch 22 loss: 2.7621452364658987


  9%|▉         | 23/252 [02:57<29:05,  7.62s/it]

batch 23 loss: 2.5010559273926485


 10%|▉         | 24/252 [03:04<28:47,  7.58s/it]

batch 24 loss: 2.6034595231529045


 10%|▉         | 25/252 [03:12<28:38,  7.57s/it]

batch 25 loss: 2.4226811413212253


 10%|█         | 26/252 [03:19<28:12,  7.49s/it]

batch 26 loss: 2.7851587019661768


 11%|█         | 27/252 [03:27<28:38,  7.64s/it]

batch 27 loss: 4.922676707315945


 11%|█         | 28/252 [03:35<29:24,  7.88s/it]

batch 28 loss: 2.2991096445634875


 12%|█▏        | 29/252 [03:44<29:43,  8.00s/it]

tmp model saved!
batch 29 loss: 3.3721811148963434


 12%|█▏        | 30/252 [03:52<29:49,  8.06s/it]

batch 30 loss: 2.650776119806694


 12%|█▏        | 31/252 [04:00<29:42,  8.06s/it]

batch 31 loss: 3.424701440263794


 13%|█▎        | 32/252 [04:08<29:46,  8.12s/it]

batch 32 loss: 2.521674920402023


 13%|█▎        | 33/252 [04:21<28:52,  7.91s/it]


KeyboardInterrupt: 

In [16]:
# model_path_loadのモデルをロード
model_path_load="/workspace/checkpoints/model_under2_5.pth"
model = EVFlowNet(args.train).to(device)
model.load_state_dict(torch.load(model_path_load, map_location=device))
train_model(args, model)

evaluate(model)

on epoch: 1


  0%|          | 0/252 [00:00<?, ?it/s]

batch 0 loss: 2.0533226726224836


  0%|          | 0/252 [00:29<?, ?it/s]


KeyboardInterrupt: 

# Mine version

In [13]:
def sliding_window_collate(batch, window_size=2, step_size=1):
    # 初期化
    seq_names = []
    event_volumes = []
    flow_gts = []
    flow_gt_valid_masks = []
    
    # バッチの各要素に対して処理を行う
    for item in batch:
        seq_names.append(item['seq_name'])
        event_volumes.append(item['event_volume'])
        flow_gts.append(item['flow_gt'])
        flow_gt_valid_masks.append(item['flow_gt_valid_mask'])
    
    # スライディングウィンドウを適用する
    sliding_batches = []
    for i in range(0, len(event_volumes) - window_size + 1, step_size):
        sliding_batch = {
            'seq_name': seq_names[i:i + window_size],
            'event_volume': torch.stack(event_volumes[i:i + window_size]),
            'flow_gt': torch.stack(flow_gts[i:i + window_size]),
            'flow_gt_valid_mask': torch.stack(flow_gt_valid_masks[i:i + window_size]),
        }
        sliding_batches.append(sliding_batch)
    
    return sliding_batches

# DataLoaderのcollate_fnを更新する
train_data = DataLoader(train_set,
                                batch_size=args.data_loader.train.batch_size,
                                shuffle=args.data_loader.train.shuffle,
                                collate_fn=lambda batch: sliding_window_collate(batch, window_size=2, step_size=1),
                                drop_last=False)

test_data = DataLoader(test_set,
                                batch_size=args.data_loader.test.batch_size,
                                shuffle=args.data_loader.test.shuffle,
                                collate_fn=lambda batch: sliding_window_collate(batch, window_size=2, step_size=1),
                                drop_last=False)


In [14]:
def slide_and_concatenate(images):
    # images: [batch_size, ch=4, width=480, height=640]
    batch_size, ch, width, height = images.shape
    
    # チャンネル数を2倍に増やすためのテンソル
    concatenated = torch.empty((batch_size - 1, ch * 2, width, height), dtype=images.dtype, device=images.device)
    
    for i in range(batch_size - 1):
        concatenated[i] = torch.cat((images[i], images[i + 1]), dim=0)
    
    return concatenated

In [None]:
def train_my_model(args, model, n_epoch=1):
    # ------------------
    #   optimizer
    # ------------------
    optimizer = torch.optim.Adam(model.parameters(), lr=args.train.initial_learning_rate, weight_decay=args.train.weight_decay)

    # ------------------
    #   Start training
    # ------------------
    model.train()
    for epoch in range(n_epoch):
        total_loss = 0
        print("on epoch: {}".format(epoch+1))
        for i, batch in enumerate(tqdm(train_data)):
            batch: Dict[str, Any]
            event_image = batch["event_volume"].to(device) # [B, 4, 480, 640]
            ground_truth_flow = batch["flow_gt"].to(device) # [B, 2, 480, 640]
            flow = model(event_image) # [B, 2, 480, 640]
            loss: torch.Tensor = compute_epe_error(flow, ground_truth_flow)
            print(f"batch {i} loss: {loss.item()}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if loss < 2.5:
                torch.save(model.state_dict(), "checkpoints/model_tmp.pth")
                print("tmp model saved!")


            total_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_data)}')
        
        current_time = time.strftime("%Y%m%d%H%M%S")
        model_path = f"checkpoints/model_{current_time}.pth"
        torch.save(model.state_dict(), model_path)
        print(f"Model saved to {model_path}")


    # ------------------
    #   Start predicting ()
    # ------------------
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    flow: torch.Tensor = torch.tensor([]).to(device)
    with torch.no_grad():
        print("start test")
        for batch in tqdm(test_data):
            batch: Dict[str, Any]
            event_image = batch["event_volume"].to(device)
            batch_flow = model(event_image) # [1, 2, 480, 640]
            flow = torch.cat((flow, batch_flow), dim=0)  # [N, 2, 480, 640]
        print("test done")
    # ------------------
    #  save submission
    # ------------------
    file_name = "submission"
    save_optical_flow_to_npy(flow, file_name)

In [11]:
from src.models.evflownet_my import EVFlowNetMy
from src.models.base import *

# model_path_loadのモデルをロード
model_path_load="/workspace/checkpoints/model_20240717025440.pth"
model = EVFlowNetMy(args.train).to(device)
model.load_state_dict(torch.load(model_path_load, map_location=device))
train_model(args, model)

on epoch: 1


  return F.conv2d(input, weight, bias, self.stride,
  0%|          | 0/252 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [None]:
save_model(model)

Model saved to checkpoints/model_20240717052428.pth
start test


  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 97/97 [00:12<00:00,  7.97it/s]


test done
Submission saved
