In [1]:
# Heavily adapted from https://pytorch.org/tutorials/beginner/introyt/trainingyt.html and https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html

import sys
import os
import torch
import torch.nn as nn
import datetime as dt
import argparse
import subprocess
import ray.cloudpickle as pickle
import numpy as np
import tempfile
import pandas as pd

from utils.dataset import VideoDataset
from ray import tune, train
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler
from models.cnn_lstm.cnn_lstm import CNN_LSTM
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from sklearn.metrics import f1_score
from pathlib import Path
from functools import partial

# ---- Runtime environment ----
# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Paths below are resolved against the notebook's current working directory.
main_folder_path = os.getcwd()

# ---- Run identifiers and paths ----
MODEL_NAME = "cnn_lstm"
NOW = dt.datetime.now()
FILENAME = NOW.strftime('%Y-%m-%d-%H-%M-%S')  # start time, e.g. 2024-11-06-02-52-22
SAVE_DIR = f"{main_folder_path}/models/cnn_lstm/saved_models"
DATA_FOLDER = "data"

# ---- Training constants ----
INF = 1e8  # large sentinel value standing in for "no loss observed yet"
NUM_WORKERS = 8  # worker processes per DataLoader
NUM_CLASSES = 2  # number of target classes (passed to CNN_LSTM as num_classes)
if torch.cuda.is_available():
    GPUS_PER_TRIAL = torch.cuda.device_count()
else:
    GPUS_PER_TRIAL = 0

# TensorBoard writer created at import time.
# NOTE(review): train_model() builds its own local writer under
# runs/cnn_lstm_<timestamp> and does not use this one. Unless another cell
# logs to it, this only produces an extra, empty run directory. Confirm
# before removing.
timestamp = dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
writer = SummaryWriter(f"runs/rcnn_{timestamp}")

def train_model(
    config: dict,
    epochs: int,
):
    """
    Train a CNN-LSTM video classifier and report per-epoch results to Ray Tune.

    Each epoch makes one pass over ``<cwd>/data/train`` and one pass over
    ``<cwd>/data/validation``. During training, the loss and the weighted F1
    are logged to TensorBoard every 10 batches. At the end of each epoch, the
    epoch-level train/validation loss and F1 are logged. The mean validation
    loss (``"loss"``) and the validation weighted F1 (``"f1"``) are then passed
    to ``train.report``, together with a checkpoint holding the epoch index,
    the model weights and the optimizer state. When Ray supplies a checkpoint,
    training resumes at the epoch after the one it stores.

    Args:
        config (dict): Ray Tune hyperparameter dictionary. Keys read:
            ``steps``, ``batch_size``, ``input_channels``, ``num_cnn_layers``,
            ``num_kernels``, ``kernel_size``, ``dropout_prob``,
            ``num_lstm_layers``, ``hidden_size``, ``bidirectional`` and ``lr``.
        epochs (int): Total number of epochs, counting any already completed
            in a restored checkpoint.

    Returns:
        None

    Raises:
        ValueError: If the training or validation loader yields no batches.
    """
    batch_size = int(config["batch_size"])

    train_dataset = VideoDataset(
        root=f"{main_folder_path}/data/train",
        clip_len=config["steps"]
    )
    # Shuffle so the optimiser does not see samples in the dataset's fixed
    # enumeration order on every epoch.
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS
    )

    val_dataset = VideoDataset(
        root=f"{main_folder_path}/data/validation",
        clip_len=config["steps"]
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=NUM_WORKERS
    )

    # Fail fast. An empty loader would otherwise surface later as a division
    # by zero or a torch.cat() on an empty list.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("Training and validation datasets must both be non-empty.")

    model = CNN_LSTM(
        input_channels=int(config["input_channels"]),
        num_cnn_layers=int(config["num_cnn_layers"]),
        num_kernels=int(config["num_kernels"]),
        kernel_size=int(config["kernel_size"]),
        stride=1,
        padding="same",
        dropout_prob=float(config["dropout_prob"]),
        bias=False,
        num_lstm_layers=int(config["num_lstm_layers"]),
        hidden_size=int(config["hidden_size"]),
        num_classes=NUM_CLASSES,
        bidirectional=bool(config["bidirectional"]),
        input_shape=(224, 224),
        steps=int(config["steps"])
    )

    print(device)
    model = model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    start_epoch = 0
    checkpoint = get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            # NOTE: unpickling can execute arbitrary code. Only restore
            # checkpoints produced by this training function.
            with open(data_path, "rb") as fp:
                checkpoint_state = pickle.load(fp)
            # The stored epoch was already completed before the checkpoint was
            # written, so continue with the next one instead of repeating it.
            start_epoch = checkpoint_state["epoch"] + 1
            model.load_state_dict(checkpoint_state["net_state_dict"])
            optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])

    writer = SummaryWriter(f"runs/cnn_lstm_{timestamp}")
    print(f"Starting training at epoch: {start_epoch}")
    try:
        for epoch in tqdm(range(start_epoch, epochs)):
            print(f"Epoch: {epoch}")

            # ---- Training pass ----
            model.train()
            window_loss = 0.0  # loss summed over the current 10-batch logging window
            epoch_loss = 0.0   # loss summed over every training batch this epoch
            train_labels, train_preds = [], []
            for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
                vid_inputs, labels = data["video"].to(device), data["target"].to(device)

                optimizer.zero_grad()
                output = model(vid_inputs)
                loss = loss_fn(output, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                window_loss += batch_loss
                epoch_loss += batch_loss
                train_labels.append(labels.cpu())
                train_preds.append(output.argmax(dim=1).cpu())

                if i % 10 == 9:
                    last_loss = window_loss / 10
                    # F1 over every training batch seen so far in this epoch.
                    running_f1 = f1_score(
                        torch.cat(train_labels), torch.cat(train_preds), average="weighted"
                    )
                    print(f"Batch: {i + 1}, Loss: {last_loss}, F1 score: {running_f1}")
                    tb_x = epoch * len(train_loader) + i + 1

                    writer.add_scalar("Loss/train", last_loss, tb_x)
                    writer.add_scalar("F1 score/train", running_f1, tb_x)

                    window_loss = 0.0

            # Epoch-level training metrics. These are defined even when the
            # loader has fewer than 10 batches.
            train_loss = epoch_loss / len(train_loader)
            train_f1 = f1_score(
                torch.cat(train_labels), torch.cat(train_preds), average="weighted"
            )

            # ---- Validation pass ----
            model.eval()
            val_loss_sum = 0.0  # fresh accumulator so no training loss leaks in
            val_labels, val_preds = [], []
            with torch.no_grad():
                for vdata in val_loader:
                    vid_inputs, labels = vdata["video"].to(device), vdata["target"].to(device)
                    output = model(vid_inputs)
                    val_loss_sum += loss_fn(output, labels).item()
                    val_labels.append(labels.cpu())
                    val_preds.append(output.argmax(dim=1).cpu())
            # Mean of the per-batch (already batch-averaged) losses.
            avg_vloss = val_loss_sum / len(val_loader)
            val_f1 = f1_score(
                torch.cat(val_labels), torch.cat(val_preds), average="weighted"
            )
            print(f"Validation Loss: {avg_vloss}, Validation F1 score: {val_f1}")
            print(f"Train Loss: {train_loss}, Val Loss: {avg_vloss}")

            writer.add_scalars("Training vs Validation Loss",
                               {"Train": train_loss, "Validation": avg_vloss},
                               epoch + 1)
            writer.add_scalars("Training vs Validation F1 Score",
                               {"Train": train_f1, "Validation": val_f1},
                               epoch + 1)
            writer.flush()

            checkpoint_data = {
                "epoch": epoch,
                "net_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            }

            with tempfile.TemporaryDirectory() as checkpoint_dir:
                data_path = Path(checkpoint_dir) / "data.pkl"
                with open(data_path, "wb") as fp:
                    pickle.dump(checkpoint_data, fp)

                checkpoint = Checkpoint.from_directory(checkpoint_dir)
                train.report({
                    "loss": avg_vloss,
                    "f1": val_f1
                }, checkpoint=checkpoint)
    finally:
        # Release the event-file handle even if training raises.
        writer.close()

    print("Finished training")

  from .autonotebook import tqdm as notebook_tqdm
2024-11-06 02:52:22,353	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-11-06 02:52:23,127	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-11-06 02:52:23.286101: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-06 02:52:23.294725: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1730832743.304311  102377 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to re

In [2]:
# Hyperparameters for a single, manually launched training run.
config = dict(
    # Model architecture (forwarded to CNN_LSTM)
    input_channels=6,
    num_cnn_layers=8,
    num_kernels=32,
    kernel_size=10,
    dropout_prob=0.25,
    num_lstm_layers=5,
    hidden_size=32,
    bidirectional=False,
    # Clip length: used as VideoDataset clip_len and CNN_LSTM steps
    steps=100,
    # Optimisation
    batch_size=1,
    lr=1e-5,
)

In [3]:
# Display the working directory that all data/model paths are resolved against.
os.path.abspath(os.curdir)

'/home/wilsonwid/github-repos/dsa4266-project'

In [None]:
# Launch a 3-epoch training run with the hyperparameters defined above.
train_model(config, epochs=3)



cuda
Starting training at epoch: 0


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch: 0


  return F.conv3d(


Batch: 10, Loss: 0.6946468889713288, F1 score: 0.18461538461538463




Batch: 20, Loss: 0.693981921672821, F1 score: 0.27199999999999996




Batch: 30, Loss: 0.6923047840595246, F1 score: 0.391382687034861




Batch: 40, Loss: 0.6910820543766022, F1 score: 0.3779580797836376




Batch: 50, Loss: 0.6937579989433289, F1 score: 0.35498901098901103




Batch: 60, Loss: 0.6930661618709564, F1 score: 0.3674074074074074




Batch: 70, Loss: 0.6918929934501648, F1 score: 0.404063492063492




Batch: 80, Loss: 0.6917422473430633, F1 score: 0.4230769230769231




Batch: 90, Loss: 0.6926373183727265, F1 score: 0.405431018771642




Batch: 100, Loss: 0.6938655734062195, F1 score: 0.4146216835899116




Batch: 110, Loss: 0.6927374303340912, F1 score: 0.42746344564526373




Batch: 120, Loss: 0.6920586347579956, F1 score: 0.43163289075238764




Batch: 130, Loss: 0.6956316947937011, F1 score: 0.4077453790049887




Batch: 140, Loss: 0.6942847549915314, F1 score: 0.40822981366459626




Batch: 150, Loss: 0.6927455484867096, F1 score: 0.4124




Batch: 160, Loss: 0.6937171518802643, F1 score: 0.41173864894795126




Batch: 170, Loss: 0.6954814553260803, F1 score: 0.4023456284437533




Batch: 180, Loss: 0.6921037495136261, F1 score: 0.41190185617815595




Batch: 190, Loss: 0.691452544927597, F1 score: 0.41481096686862834




Batch: 200, Loss: 0.6931620240211487, F1 score: 0.41738051974272444




Batch: 210, Loss: 0.692848163843155, F1 score: 0.42174024539393223




Batch: 220, Loss: 0.6941696107387543, F1 score: 0.4193312944166187




Batch: 230, Loss: 0.6917730867862701, F1 score: 0.42916331363667715




Batch: 240, Loss: 0.6945922493934631, F1 score: 0.428125




Batch: 250, Loss: 0.6923891067504883, F1 score: 0.43375903614457834




Batch: 260, Loss: 0.6929994344711303, F1 score: 0.4365610859728507




Batch: 270, Loss: 0.6916969299316407, F1 score: 0.4465619366469225




Batch: 280, Loss: 0.6923664331436157, F1 score: 0.453125




Batch: 290, Loss: 0.6922597467899323, F1 score: 0.454472066924046




Batch: 300, Loss: 0.6922758400440217, F1 score: 0.4629384576190996




Batch: 310, Loss: 0.692800772190094, F1 score: 0.4564838709677419




Batch: 320, Loss: 0.6935778319835663, F1 score: 0.45494804973599046




Batch: 330, Loss: 0.6945893824100494, F1 score: 0.44934088165976127




Batch: 340, Loss: 0.6937411010265351, F1 score: 0.4420540903983792




Batch: 350, Loss: 0.692115741968155, F1 score: 0.4451440243444912




Batch: 360, Loss: 0.6935293257236481, F1 score: 0.4482365848960209




Batch: 370, Loss: 0.6935611844062806, F1 score: 0.4474660822428977




Batch: 380, Loss: 0.6940670073032379, F1 score: 0.44654764831590293




Batch: 390, Loss: 0.692977923154831, F1 score: 0.44936012462287783




Batch: 400, Loss: 0.6954793751239776, F1 score: 0.440990014785409




Batch: 410, Loss: 0.6955307185649872, F1 score: 0.43338694795738614




Batch: 420, Loss: 0.6916327893733978, F1 score: 0.4446521880958025




Batch: 430, Loss: 0.6931079626083374, F1 score: 0.4532862013212786




Batch: 440, Loss: 0.6931180536746979, F1 score: 0.45388048257106445




Batch: 450, Loss: 0.6933177828788757, F1 score: 0.45150759057121226




Batch: 460, Loss: 0.6935235857963562, F1 score: 0.4492400756143668




Batch: 470, Loss: 0.6940839529037476, F1 score: 0.4453341491335412




Batch: 480, Loss: 0.6930683612823486, F1 score: 0.4444444444444445




Batch: 490, Loss: 0.6939326584339142, F1 score: 0.44523865192558204




Batch: 500, Loss: 0.6926870346069336, F1 score: 0.4460021436227224




Batch: 510, Loss: 0.6922154664993286, F1 score: 0.4466513878278584




Batch: 520, Loss: 0.6927872121334075, F1 score: 0.4461538461538462




Batch: 530, Loss: 0.6933934569358826, F1 score: 0.4494150943396226




Batch: 540, Loss: 0.6932486355304718, F1 score: 0.4535797875661837


