0. Install Packages

In [None]:
!pip -q install snntorch remotezip tqdm opencv-python-headless

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25h

1. Download UCF-101 and Unzip

In [None]:
import random, pathlib, collections
from tqdm import tqdm
import remotezip as rz

# -----------------------------
# Download subset (mimic your notebook style)
# -----------------------------
URL = "https://storage.googleapis.com/thumos14_files/UCF101_videos.zip"

def get_class(fname: str) -> str:
    """Return the action-class token of a UCF-101 video file name.

    The name is split on underscores and the third field counted from the
    end is returned, e.g. ``v_ApplyEyeMakeup_g01_c01.avi -> ApplyEyeMakeup``.

    Raises:
        IndexError: if ``fname`` has fewer than three underscore-separated fields.
    """
    tokens = fname.split("_")
    return tokens[-3]

def list_video_files(zip_url: str):
    """Collect the member names of all nested ``.avi`` files in a remote zip.

    Args:
        zip_url: URL of the archive, opened through ``remotezip.RemoteZip``.

    Returns:
        List of archive member names that end in ``.avi`` and contain a ``/``
        (i.e. live inside a folder of the archive), in archive order.
    """
    with rz.RemoteZip(zip_url) as archive:
        return [
            member.filename
            for member in archive.infolist()
            if member.filename.endswith(".avi") and "/" in member.filename
        ]

def get_files_per_class(files):
    """Group archive member paths by the class parsed from their base name.

    Args:
        files: iterable of archive member paths.

    Returns:
        ``collections.defaultdict(list)`` mapping class name -> member paths,
        each list keeping the order in which the paths were given.
    """
    grouped = collections.defaultdict(list)
    for member_path in files:
        # Parse the class from the base name only, so folder names in the
        # archive path cannot interfere with the underscore split.
        base_name = pathlib.Path(member_path).name
        grouped[get_class(base_name)].append(member_path)
    return grouped

def split_class_lists(files_for_class, count):
    """Take the first ``count`` files of every class and return the rest.

    Args:
        files_for_class: mapping class name -> list of files.
        count: number of files to take from the front of each list.

    Returns:
        ``(taken, leftover)`` where ``taken`` is one flat list (classes in
        mapping order) and ``leftover`` maps each class to its remaining files.
    """
    taken = []
    leftover = {}
    for cls, members in files_for_class.items():
        taken += members[:count]
        leftover[cls] = members[count:]
    return taken, leftover

def download_from_zip(zip_url, to_dir: pathlib.Path, file_names):
    """Download archive members into ``to_dir/<class>/<basename>``.

    Args:
        zip_url: URL of the remote zip archive, opened with ``remotezip``.
        to_dir: destination root; one sub-directory per class is created in it.
        file_names: archive member paths to fetch.

    Members whose destination file already exists are skipped, so re-running
    the notebook does not download them again.
    """
    import shutil  # local import: only this function needs it

    to_dir.mkdir(parents=True, exist_ok=True)

    # Resolve every destination first so the archive is only contacted when
    # there is actually something left to fetch.
    pending = []
    for fn in file_names:
        base_name = pathlib.Path(fn).name
        out_path = to_dir / get_class(base_name) / base_name
        if not out_path.exists():
            pending.append((fn, out_path))
    if not pending:
        return

    with rz.RemoteZip(zip_url) as z:
        for fn, out_path in tqdm(pending, desc=f"extract -> {to_dir.name}", leave=False):
            out_path.parent.mkdir(parents=True, exist_ok=True)

            # Stream the member straight to its flat destination instead of
            # extracting under the archive's internal path and renaming.
            # That avoids guessing where the extracted file landed and
            # leaves no nested folders behind.
            tmp_path = out_path.with_name(out_path.name + ".part")
            with z.open(fn) as src, open(tmp_path, "wb") as dst:
                shutil.copyfileobj(src, dst)
            # Publish only complete files: an interrupted download leaves a
            # ".part" file that the existence check above does not accept.
            tmp_path.replace(out_path)

def download_ucf101_subset(zip_url, num_classes, splits, download_dir: pathlib.Path, seed=0):
    """
    Download a class-balanced subset of the archive, one folder per split.

    splits example: {"train": 30, "val": 10, "test": 10}  (per class counts)

    The first ``num_classes`` class names in sorted order are used. Each
    class's file list is shuffled after seeding the global ``random`` module
    with ``seed``; splits then take consecutive, non-overlapping slices of
    that shuffled list in the order they appear in ``splits``.

    Returns a dict mapping split name -> ``download_dir / split name``.
    """
    random.seed(seed)

    by_class = get_files_per_class(list_video_files(zip_url))

    # choose classes (deterministic) and shuffle each class's files in place
    chosen = sorted(by_class.keys())[:num_classes]
    remaining = {}
    for cls in chosen:
        members = by_class[cls]
        random.shuffle(members)
        remaining[cls] = members

    dirs = {}
    for split_name, per_class_count in splits.items():
        split_dir = download_dir / split_name
        # `remaining` shrinks after every split, which keeps splits disjoint
        split_files, remaining = split_class_lists(remaining, per_class_count)
        download_from_zip(zip_url, split_dir, split_files)
        dirs[split_name] = split_dir

    return dirs

# --------- run ----------
# Fetch the first 10 classes (sorted by name) with 30/10/10 videos per class
# for train/val/test. `subset_paths` maps each split name to its directory and
# is read by the training cells further down.
download_dir = pathlib.Path("./UCF101_subset")
subset_paths = download_ucf101_subset(
    URL,
    num_classes=10,
    splits={"train": 30, "val": 10, "test": 10},
    download_dir=download_dir,
    seed=0
)




2. Rate Encoding

In [None]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import snntorch as snn
from snntorch import spikegen
import random
import pathlib

# =========================================================
# Config
# =========================================================
device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is visible

batch_size = 8
num_epochs = 80  # also the cosine-annealing period (T_max) of the scheduler below
lr = 2e-3        # initial learning rate handed to AdamW

T = 16            # frames sampled per clip; passed to the datasets as n_frames
rate_scale = 1.0  # multiplier applied to the firing probabilities in rate_encode_video
tau = 2.0
beta = torch.exp(torch.tensor(-1.0 / tau)).item()  # LIF decay factor, beta = exp(-1/tau)

num_classes = 10  # output size of the network; expected to match the number of downloaded classes
H, W = 112, 112   # frame size; read as globals by UCFSubsetDataset.__getitem__

# =========================================================
# Video Utilities
# =========================================================
def resize_with_pad(frame_bgr, out_hw=(112,112)):
    """Letterbox a frame: scale it to fit ``out_hw`` and centre it on a black canvas.

    Args:
        frame_bgr: image array whose first two axes are (height, width).
        out_hw: (height, width) of the returned canvas.

    Returns:
        uint8 array of shape (out_h, out_w, 3). The aspect ratio of the input
        is preserved; uncovered borders are zero. An all-zero canvas is
        returned when the input has zero height or width.
    """
    out_h, out_w = out_hw
    h, w = frame_bgr.shape[:2]
    if h == 0 or w == 0:
        return np.zeros((out_h, out_w, 3), dtype=np.uint8)

    scale = min(out_w / w, out_h / h)
    # int() truncates, so a very elongated frame could end up with a side of
    # 0 pixels, which cv2.resize rejects. Keep at least one pixel per side.
    nw = max(1, int(w * scale))
    nh = max(1, int(h * scale))
    resized = cv2.resize(frame_bgr, (nw, nh), interpolation=cv2.INTER_AREA)

    canvas = np.zeros((out_h, out_w, 3), dtype=np.uint8)
    top = (out_h - nh) // 2
    left = (out_w - nw) // 2
    canvas[top:top+nh, left:left+nw] = resized
    return canvas

def frames_from_video(video_path, n_frames, out_hw=(112,112), frame_step=2, training=True):
    """Read a fixed-length clip from a video file.

    One frame is taken at the start position, then every ``frame_step``-th
    frame after it, and each kept frame is letterboxed with
    ``resize_with_pad``.

    Args:
        video_path: path of the video; converted with ``str`` for OpenCV.
        n_frames: number of frames in the clip.
        out_hw: (height, width) of every returned frame.
        frame_step: number of ``cap.read()`` calls between two kept frames.
        training: if True, the start frame is drawn with ``random.randint``
            whenever the reported frame count is long enough; otherwise the
            clip starts at frame 0.

    Returns:
        uint8 RGB array of shape (n_frames, out_h, out_w, 3). The clip is all
        zeros when the file cannot be opened or its first frame cannot be
        read; frames that cannot be read later are replaced by zero frames.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return np.zeros((n_frames, out_hw[0], out_hw[1], 3), dtype=np.uint8)

        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Number of source frames spanned by the clip.
        need_len = 1 + (n_frames - 1) * frame_step
        if length <= 0 or need_len > length:
            start = 0
        else:
            max_start = length - need_len
            start = random.randint(0, max_start) if training else 0

        cap.set(cv2.CAP_PROP_POS_FRAMES, start)

        ret, frame = cap.read()
        if not ret:
            return np.zeros((n_frames, out_hw[0], out_hw[1], 3), dtype=np.uint8)

        frames = [resize_with_pad(frame, out_hw)]

        for _ in range(n_frames - 1):
            # Advance frame_step frames; only the last one read is kept.
            for _ in range(frame_step):
                ret, frame = cap.read()
            if ret:
                frames.append(resize_with_pad(frame, out_hw))
            else:
                frames.append(np.zeros_like(frames[0]))
    finally:
        # Release the capture handle on every path, including exceptions
        # raised while a frame is being processed.
        cap.release()

    frames = np.stack(frames, axis=0)[..., ::-1].copy()  # BGR -> RGB, make contiguous
    return frames

# =========================================================
# Dataset
# =========================================================
# Per-channel statistics used to normalize frames, shaped to broadcast over [T,3,H,W].
IM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IM_STD  = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

class UCFSubsetDataset(Dataset):
    """Clips from a ``<split_dir>/<class>/<video>.avi`` folder tree.

    Each item is ``(x, y)`` with ``x`` a normalized float clip of shape
    [T,3,H,W] and ``y`` a scalar long tensor holding the class index.
    """

    def __init__(self, split_dir, n_frames=16, frame_step=2, training=False, class_to_idx=None):
        self.split_dir = pathlib.Path(split_dir)
        self.n_frames, self.frame_step, self.training = n_frames, frame_step, training

        self.video_paths = sorted(self.split_dir.glob("*/*.avi"))
        folder_names = [vp.parent.name for vp in self.video_paths]
        self.class_names = sorted(set(folder_names))

        # Build the label mapping from the folder names unless the caller
        # supplies one (used so val/test share the training indices).
        if class_to_idx is not None:
            self.class_to_idx = class_to_idx
        else:
            self.class_to_idx = dict(zip(self.class_names, range(len(self.class_names))))

        self.labels = [self.class_to_idx[name] for name in folder_names]

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        path, label = self.video_paths[idx], self.labels[idx]

        # uint8 [T,H,W,3] -> float [T,3,H,W] in [0,1], then channel-wise normalize
        clip = frames_from_video(path, self.n_frames, (H, W), self.frame_step, self.training)
        clip = torch.from_numpy(clip).permute(0, 3, 1, 2).float() / 255.0
        clip = (clip - IM_MEAN) / IM_STD
        return clip, torch.tensor(label, dtype=torch.long)

# =========================================================
# Rate Encoder
# =========================================================
def rate_encode_video(video_btchw, rate_scale=1.0):
    """
    Turn a normalized clip into spikes, one ``spikegen.rate`` call per frame.

    video_btchw: [B,T,C,H,W]
    return: [T,B,C,H,W]

    Values are squashed to (0, 1) with ``(tanh(x) + 1) / 2``, multiplied by
    ``rate_scale`` and clipped to [0, 1] before being handed to
    ``spikegen.rate`` with ``num_steps=1``.
    """
    _, n_steps, _, _, _ = video_btchw.shape  # also enforces a 5-D input

    prob = (torch.tanh(video_btchw) + 1) * 0.5
    prob = (prob * rate_scale).clamp(0.0, 1.0)

    # spikegen.rate prepends a length-1 step axis; [0] drops it again.
    spikes = [spikegen.rate(prob[:, step], num_steps=1)[0] for step in range(n_steps)]
    return torch.stack(spikes, dim=0)

# =========================================================
# Conv3D SNN
# =========================================================
class Conv3DSNN(nn.Module):
    def __init__(self, beta=0.95, num_classes=10):
        super().__init__()

        self.conv1 = nn.Conv3d(3, 32, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(32)
        self.lif1 = snn.Leaky(beta=beta)

        self.conv2 = nn.Conv3d(32, 64, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(64)
        self.lif2 = snn.Leaky(beta=beta)

        self.conv3 = nn.Conv3d(64, 128, 3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm3d(128)
        self.lif3 = snn.Leaky(beta=beta)

        self.pool = nn.MaxPool3d((1,2,2))
        self.gap = nn.AdaptiveAvgPool3d((None,1,1))

        self.fc1 = nn.Linear(128, 256)
        self.lif4 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(256, num_classes)
        self.lif5 = snn.Leaky(beta=beta)

    def lif_time(self, x, lif, mem):
        B,C,T_,H_,W_ = x.shape
        spk_list=[]
        for t in range(T_):
            spk_t, mem = lif(x[:,:,t], mem)
            spk_list.append(spk_t)
        return torch.stack(spk_list, dim=2), mem

    def forward(self, spk_in):
        x = spk_in.permute(1,2,0,3,4)  # B,C,T,H,W

        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()
        mem5 = self.lif5.init_leaky()

        x = self.conv1(x); x = self.bn1(x); x,mem1=self.lif_time(x,self.lif1,mem1); x=self.pool(x)
        x = self.conv2(x); x = self.bn2(x); x,mem2=self.lif_time(x,self.lif2,mem2); x=self.pool(x)
        x = self.conv3(x); x = self.bn3(x); x,mem3=self.lif_time(x,self.lif3,mem3); x=self.pool(x)

        x = self.gap(x).squeeze(-1).squeeze(-1)  # B,128,T

        spk_out=[]
        for t in range(x.shape[2]):
            h = self.fc1(x[:,:,t])
            spk4, mem4 = self.lif4(h, mem4)
            o = self.fc2(spk4)
            spk5, mem5 = self.lif5(o, mem5)
            spk_out.append(spk5)

        return torch.stack(spk_out, dim=0)

# =========================================================
# DataLoaders (subset_paths must exist)
# =========================================================
# `T` is passed positionally as n_frames. Val/test reuse the training label
# mapping so that class indices agree across splits.
train_ds = UCFSubsetDataset(subset_paths["train"], T, training=True)
class_to_idx = train_ds.class_to_idx
val_ds = UCFSubsetDataset(subset_paths["val"], T, training=False, class_to_idx=class_to_idx)
test_ds = UCFSubsetDataset(subset_paths["test"], T, training=False, class_to_idx=class_to_idx)

# Only the training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)
test_loader = DataLoader(test_ds, batch_size=batch_size)

# =========================================================
# Train
# =========================================================
model = Conv3DSNN(beta=beta, num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
# One cosine period over the whole run; stepped once per epoch in the loop below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

@torch.no_grad()
def evaluate(loader):
    """Return the top-1 accuracy of the global ``model`` on ``loader``.

    Each batch is rate-encoded with the global ``rate_scale``, the output
    spikes are summed over time, and the class with the most spikes is the
    prediction. Puts ``model`` in eval mode as a side effect.

    Returns:
        Fraction of correctly classified samples; 0.0 if the loader yields
        no samples.
    """
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        spk_in = rate_encode_video(x, rate_scale)
        spk_out = model(spk_in)
        logits = spk_out.sum(0)  # spike counts per class act as logits
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.numel()
    # An empty loader (e.g. a split folder without videos) must not raise
    # ZeroDivisionError.
    return correct / max(total, 1)

# Train for num_epochs epochs; after every epoch step the LR schedule and
# report the sample-weighted training loss, training accuracy and validation accuracy.
for epoch in range(1,num_epochs+1):
    model.train()
    correct,total,loss_sum=0,0,0
    for x,y in train_loader:
        x,y=x.to(device),y.to(device)
        spk_in = rate_encode_video(x, rate_scale)  # [B,T,C,H,W] -> [T,B,C,H,W] spikes
        spk_out = model(spk_in)
        logits = spk_out.sum(0)  # output spike counts over time act as logits
        loss = F.cross_entropy(logits,y)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)  # clip gradient norm to 1.0
        optimizer.step()

        # weight the batch loss by its size so the epoch average is per sample
        loss_sum += loss.item()*y.size(0)
        correct += (logits.argmax(1)==y).sum().item()
        total += y.size(0)

    scheduler.step()
    val_acc = evaluate(val_loader)  # note: evaluate() switches the model to eval mode

    print(f"Epoch {epoch:02d} | "
          f"train loss {loss_sum/total:.4f} | "
          f"train acc {correct/total:.4f} | "
          f"val acc {val_acc:.4f}")

# Final held-out score, computed once after training.
print("Test acc:", evaluate(test_loader))


Epoch 01 | train loss 2.3263 | train acc 0.1167 | val acc 0.1200
Epoch 02 | train loss 2.2885 | train acc 0.1267 | val acc 0.1700
Epoch 03 | train loss 2.2695 | train acc 0.2000 | val acc 0.1800
Epoch 04 | train loss 2.2058 | train acc 0.1967 | val acc 0.2100
Epoch 05 | train loss 2.1404 | train acc 0.2333 | val acc 0.2600
Epoch 06 | train loss 2.1736 | train acc 0.2167 | val acc 0.2600
Epoch 07 | train loss 2.0945 | train acc 0.2100 | val acc 0.2600
Epoch 08 | train loss 2.0617 | train acc 0.2767 | val acc 0.2300
Epoch 09 | train loss 2.0118 | train acc 0.2667 | val acc 0.2700
Epoch 10 | train loss 2.0755 | train acc 0.2133 | val acc 0.3200
Epoch 11 | train loss 1.9930 | train acc 0.2500 | val acc 0.2800
Epoch 12 | train loss 1.9515 | train acc 0.2667 | val acc 0.3400
Epoch 13 | train loss 1.9433 | train acc 0.2967 | val acc 0.2400
Epoch 14 | train loss 1.8943 | train acc 0.2700 | val acc 0.1900
Epoch 15 | train loss 1.9251 | train acc 0.3133 | val acc 0.3500
Epoch 16 | train loss 1.8

3. TTFS Encoding

In [None]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import snntorch as snn
from snntorch import spikegen
import random
import pathlib

# =========================================================
# Config
# =========================================================
device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is visible

batch_size = 8
num_epochs = 80  # also the cosine-annealing period (T_max) of the scheduler below
lr = 2e-3        # initial learning rate handed to AdamW

# video sampling
n_frames   = 16     # clip length (also TTFS steps)
frame_step = 2      # 1 for more motion info, 2 saves compute
H, W       = 112, 112  # frame size; read as globals by UCFSubsetDataset.__getitem__

# TTFS/latency params (forwarded to spikegen.latency by the encoder below)
tau_lat = 8.0
threshold_lat = 0.001
linear_lat = True
normalize_lat = True

# LIF params
tau = 2.0
beta = torch.exp(torch.tensor(-1.0 / tau)).item()  # LIF decay factor, beta = exp(-1/tau)

# CDF build params
cdf_bins = 512         # number of histogram bins over [0, 1]
cdf_max_batches = 200  # upper bound on training batches scanned for the histogram
cdf_agg = "max"   # "max" or "mean" to collapse video -> single image for histogram + encoding

# checkpoint (optional)
# Written after every epoch; if this file already exists, training resumes from it.
ckpt_path = "./ucf_subset_ttfs_conv3d_snn_ckpt.pth"

torch.backends.cudnn.benchmark = True


# =========================================================
# Video utils (fix negative strides by .copy())
# =========================================================
def resize_with_pad(frame_bgr, out_hw=(112,112)):
    """Letterbox a frame: scale it to fit ``out_hw`` and centre it on a black canvas.

    Args:
        frame_bgr: image array whose first two axes are (height, width).
        out_hw: (height, width) of the returned canvas.

    Returns:
        uint8 array of shape (out_h, out_w, 3). The aspect ratio of the input
        is preserved; uncovered borders are zero. An all-zero canvas is
        returned when the input has zero height or width.
    """
    out_h, out_w = out_hw
    h, w = frame_bgr.shape[:2]
    if h == 0 or w == 0:
        return np.zeros((out_h, out_w, 3), dtype=np.uint8)

    scale = min(out_w / w, out_h / h)
    # int() truncates, so a very elongated frame could end up with a side of
    # 0 pixels, which cv2.resize rejects. Keep at least one pixel per side.
    nw = max(1, int(w * scale))
    nh = max(1, int(h * scale))
    resized = cv2.resize(frame_bgr, (nw, nh), interpolation=cv2.INTER_AREA)

    canvas = np.zeros((out_h, out_w, 3), dtype=np.uint8)
    top = (out_h - nh) // 2
    left = (out_w - nw) // 2
    canvas[top:top+nh, left:left+nw] = resized
    return canvas

def frames_from_video(video_path, n_frames, out_hw=(112,112), frame_step=2, training=True):
    """Read a fixed-length clip from a video file.

    One frame is taken at the start position, then every ``frame_step``-th
    frame after it, and each kept frame is letterboxed with
    ``resize_with_pad``.

    Args:
        video_path: path of the video; converted with ``str`` for OpenCV.
        n_frames: number of frames in the clip.
        out_hw: (height, width) of every returned frame.
        frame_step: number of ``cap.read()`` calls between two kept frames.
        training: if True, the start frame is drawn with ``random.randint``
            whenever the reported frame count is long enough; otherwise the
            clip starts at frame 0.

    Returns:
        uint8 RGB array [T,H,W,3]. The clip is all zeros when the file cannot
        be opened or its first frame cannot be read; frames that cannot be
        read later are replaced by zero frames.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return np.zeros((n_frames, out_hw[0], out_hw[1], 3), dtype=np.uint8)

        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Number of source frames spanned by the clip.
        need_len = 1 + (n_frames - 1) * frame_step
        if length <= 0 or need_len > length:
            start = 0
        else:
            max_start = length - need_len
            start = random.randint(0, max_start) if training else 0

        cap.set(cv2.CAP_PROP_POS_FRAMES, start)

        ret, frame = cap.read()
        if not ret:
            return np.zeros((n_frames, out_hw[0], out_hw[1], 3), dtype=np.uint8)

        frames = [resize_with_pad(frame, out_hw)]

        for _ in range(n_frames - 1):
            # Advance frame_step frames; only the last one read is kept.
            for _ in range(frame_step):
                ret, frame = cap.read()
            if ret:
                frames.append(resize_with_pad(frame, out_hw))
            else:
                frames.append(np.zeros_like(frames[0]))
    finally:
        # Release the capture handle on every path: unopened file, early
        # return, or an exception raised while a frame is being processed.
        cap.release()

    # BGR -> RGB; make contiguous to avoid negative strides
    frames = np.stack(frames, axis=0)[..., ::-1].copy()
    return frames  # [T,H,W,3] uint8 RGB


# =========================================================
# Dataset (subset_paths from your downloader)
# x is ImageNet-normalized: [T,C,H,W]
# =========================================================
# Per-channel statistics used to normalize frames, shaped to broadcast over [T,3,H,W].
IM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IM_STD  = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

class UCFSubsetDataset(Dataset):
    """Clips from a ``<split_dir>/<class>/<video>.avi`` folder tree.

    Each item is ``(x, y)`` with ``x`` a normalized float clip of shape
    [T,3,H,W] and ``y`` a scalar long tensor holding the class index.
    """

    def __init__(self, split_dir, n_frames=16, frame_step=2, training=False, class_to_idx=None):
        self.split_dir = pathlib.Path(split_dir)
        self.n_frames, self.frame_step, self.training = n_frames, frame_step, training

        self.video_paths = sorted(self.split_dir.glob("*/*.avi"))
        folder_names = [vp.parent.name for vp in self.video_paths]
        self.class_names = sorted(set(folder_names))

        # Build the label mapping from the folder names unless the caller
        # supplies one (used so val/test share the training indices).
        if class_to_idx is not None:
            self.class_to_idx = class_to_idx
        else:
            self.class_to_idx = dict(zip(self.class_names, range(len(self.class_names))))

        self.labels = [self.class_to_idx[name] for name in folder_names]

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        path, label = self.video_paths[idx], self.labels[idx]

        clip = frames_from_video(path, self.n_frames, (H, W), self.frame_step, self.training)  # [T,H,W,3]
        clip = torch.from_numpy(clip).permute(0, 3, 1, 2).float() / 255.0                      # [T,3,H,W]
        clip = (clip - IM_MEAN) / IM_STD
        return clip, torch.tensor(label, dtype=torch.long)

def collate_btchw(batch):
    """Stack a list of ``(clip, label)`` samples into batch tensors.

    Returns:
        ``(x, y)`` with ``x`` of shape [B,T,3,H,W] and ``y`` of shape [B].
    """
    clips, labels = zip(*batch)
    return torch.stack(clips, dim=0), torch.stack(labels, dim=0)


# =========================================================
# CDF build (mimic your CIFAR CDF builder)
# =========================================================
@torch.no_grad()
def build_cdf_from_trainloader_ucf(train_loader, num_bins=512, max_batches=200, device="cuda", agg="max"):
    """Estimate the empirical CDF of pixel intensities from training batches.

    Every batch is de-normalized back to [0, 1], collapsed over the time axis
    (``agg`` = "max" or "mean"), and histogrammed into ``num_bins`` equal
    bins. At most ``max_batches`` batches are used (all of them if None).

    Returns:
        ``(bin_edges, cdf)``: ``num_bins + 1`` evenly spaced edges on [0, 1]
        and the cumulative sum of the normalized histogram (``num_bins`` values).

    Raises:
        ValueError: if ``agg`` is neither "max" nor "mean" (raised when the
            first batch is processed).
    """
    mean = IM_MEAN.to(device)  # [1,3,1,1]
    std = IM_STD.to(device)
    counts = torch.zeros(num_bins, device=device)

    for batch_idx, (x_norm, _) in enumerate(train_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break

        # undo the dataset normalization: [B,T,3,H,W] back in [0,1]
        x_raw = (x_norm.to(device, non_blocking=True) * std + mean).clamp(0.0, 1.0)

        if agg == "max":
            collapsed = x_raw.max(dim=1).values  # [B,3,H,W]
        elif agg == "mean":
            collapsed = x_raw.mean(dim=1)
        else:
            raise ValueError("agg must be 'max' or 'mean'")

        counts += torch.histc(collapsed.flatten(), bins=num_bins, min=0.0, max=1.0)

    # the small epsilon keeps the division defined when nothing was counted
    pmf = counts / (counts.sum() + 1e-12)
    edges = torch.linspace(0.0, 1.0, steps=num_bins + 1, device=device)
    return edges, torch.cumsum(pmf, dim=0)


# =========================================================
# TTFS encoder (equalized latency) for UCF subset
# Output: [T_steps, B, 3, H, W]
# =========================================================
@torch.no_grad()
def ttfs_encode_ucf_equalized_latency(
    x_video_norm, T_steps, bin_edges, cdf,
    agg="max",
    normalize=True, linear=True, tau=8.0, threshold=0.001
):
    """
    Latency-encode a clip after histogram-equalizing its pixel intensities.

    x_video_norm: [B,T,3,H,W] normalized

    The clip is de-normalized to [0, 1], collapsed over time (``agg`` = "max"
    or "mean"), and every pixel is replaced by the CDF value of its histogram
    bin. The equalized image is then passed to ``spikegen.latency`` together
    with ``T_steps``, ``normalize``, ``linear``, ``tau`` and ``threshold``.

    Args:
        bin_edges: ``cdf.numel() + 1`` ascending bin edges.
        cdf: cumulative distribution value for each bin.

    Returns:
        The output of ``spikegen.latency`` — [T_steps,B,3,H,W].

    Raises:
        ValueError: if ``agg`` is neither "max" nor "mean".
    """
    mean = IM_MEAN.to(x_video_norm.device)
    std  = IM_STD.to(x_video_norm.device)

    x_raw = (x_video_norm * std + mean).clamp(0.0, 1.0)  # [B,T,3,H,W]

    if agg == "max":
        x_img = x_raw.max(dim=1).values                  # [B,3,H,W]
    elif agg == "mean":
        x_img = x_raw.mean(dim=1)
    else:
        raise ValueError("agg must be 'max' or 'mean'")

    nb = cdf.numel()
    # right=True makes bins left-closed ([edge_i, edge_{i+1})), matching how
    # torch.histc counted values when the CDF was built; with right=False a
    # value lying exactly on an interior edge would be looked up one bin too low.
    idx = torch.bucketize(x_img, bin_edges[1:-1], right=True).clamp(0, nb - 1)
    # keep u strictly inside (0, 1) before latency coding
    u = cdf[idx].clamp(1e-4, 1.0 - 1e-4)

    spk = spikegen.latency(
        u,
        num_steps=T_steps,
        normalize=normalize,
        linear=linear,
        tau=tau,
        threshold=threshold
    )
    return spk  # [T_steps,B,3,H,W]


# =========================================================
# Conv3D SNN (TTFS input)
# - last layer outputs analog logits for stability
# =========================================================
class Conv3DSNN_TTFS(nn.Module):
    def __init__(self, beta=0.95, num_classes=10):
        super().__init__()

        self.conv1 = nn.Conv3d(3, 32, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(32)
        self.lif1 = snn.Leaky(beta=beta)

        self.conv2 = nn.Conv3d(32, 64, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(64)
        self.lif2 = snn.Leaky(beta=beta)

        self.conv3 = nn.Conv3d(64, 128, 3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm3d(128)
        self.lif3 = snn.Leaky(beta=beta)

        self.pool = nn.MaxPool3d((1,2,2))
        self.gap = nn.AdaptiveAvgPool3d((None,1,1))

        self.fc1 = nn.Linear(128, 256)
        self.lif4 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(256, num_classes)  # analog logits

    def _lif_over_time_2d(self, x_bcthw, lif, mem):
        B,C,T,H,W = x_bcthw.shape
        spk_list = []
        for t in range(T):
            spk_t, mem = lif(x_bcthw[:, :, t], mem)  # [B,C,H,W]
            spk_list.append(spk_t)
        return torch.stack(spk_list, dim=2), mem

    def forward(self, spk_in):
        """
        spk_in: [T,B,3,H,W]
        return logits_rec: [T,B,num_classes]
        """
        x = spk_in.permute(1,2,0,3,4).contiguous()  # [B,3,T,H,W]

        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()

        x = self.conv1(x); x = self.bn1(x); x, mem1 = self._lif_over_time_2d(x, self.lif1, mem1); x = self.pool(x)
        x = self.conv2(x); x = self.bn2(x); x, mem2 = self._lif_over_time_2d(x, self.lif2, mem2); x = self.pool(x)
        x = self.conv3(x); x = self.bn3(x); x, mem3 = self._lif_over_time_2d(x, self.lif3, mem3); x = self.pool(x)

        x = self.gap(x).squeeze(-1).squeeze(-1)  # [B,128,T]

        logits_list = []
        for t in range(x.shape[2]):
            h = self.fc1(x[:,:,t])
            spk4, mem4 = self.lif4(h, mem4)
            logits_t = self.fc2(spk4)
            logits_list.append(logits_t)

        return torch.stack(logits_list, dim=0)


# =========================================================
# Build loaders
# =========================================================
# Val/test reuse the training label mapping so that class indices agree across splits.
train_ds = UCFSubsetDataset(subset_paths["train"], n_frames=n_frames, frame_step=frame_step, training=True)
class_to_idx = train_ds.class_to_idx
val_ds   = UCFSubsetDataset(subset_paths["val"],   n_frames=n_frames, frame_step=frame_step, training=False, class_to_idx=class_to_idx)
test_ds  = UCFSubsetDataset(subset_paths["test"],  n_frames=n_frames, frame_step=frame_step, training=False, class_to_idx=class_to_idx)

num_classes = len(class_to_idx)
T_steps = train_ds.n_frames  # the encoder emits one time step per sampled frame

# Only the training loader shuffles and drops the last incomplete batch.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2, pin_memory=True, collate_fn=collate_btchw)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2, pin_memory=True, collate_fn=collate_btchw)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2, pin_memory=True, collate_fn=collate_btchw)

print("num_classes =", num_classes, " | T_steps =", T_steps, " | train batches =", len(train_loader))


# =========================================================
# Build CDF once
# =========================================================
# NOTE(review): the CDF is rebuilt from randomly sampled training clips on
# every run and is not stored in the checkpoint, so a resumed run may use a
# slightly different CDF than the run that wrote the checkpoint.
bin_edges, cdf = build_cdf_from_trainloader_ucf(
    train_loader, num_bins=cdf_bins, max_batches=cdf_max_batches, device=device, agg=cdf_agg
)
print("CDF built.")


# =========================================================
# Model + Optim
# =========================================================
model = Conv3DSNN_TTFS(beta=beta, num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
# One cosine period over the whole run; stepped once per epoch in the loop below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


# =========================================================
# Resume (optional)
# =========================================================
# If a checkpoint exists, restore model/optimizer/scheduler state and continue
# with the epoch after the one stored in it.
start_epoch = 1
if os.path.exists(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    scheduler.load_state_dict(ckpt["scheduler_state"])
    start_epoch = ckpt["epoch"] + 1
    print(f"Resumed from epoch {start_epoch}")


# =========================================================
# Eval
# =========================================================
@torch.no_grad()
def evaluate(loader):
    """Score the global ``model`` on ``loader`` with TTFS-encoded inputs.

    Per-step logits are averaged over time before cross-entropy and argmax.
    Puts ``model`` in eval mode as a side effect.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample; ``(0.0, 0.0)``
        when the loader yields no samples.
    """
    model.eval()
    n_correct, n_seen, loss_total = 0, 0, 0.0

    for clips, labels in loader:
        clips = clips.to(device, non_blocking=True)    # [B,T,3,H,W]
        labels = labels.to(device, non_blocking=True)

        spikes = ttfs_encode_ucf_equalized_latency(
            clips, T_steps=T_steps, bin_edges=bin_edges, cdf=cdf,
            agg=cdf_agg, normalize=normalize_lat, linear=linear_lat,
            tau=tau_lat, threshold=threshold_lat
        )  # [T,B,3,H,W]

        logits = model(spikes).mean(dim=0)  # [T,B,C] -> [B,C]
        batch_loss = F.cross_entropy(logits, labels)

        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.numel()
        # weight the batch mean by its size so the result is a per-sample average
        loss_total += batch_loss.item() * labels.size(0)

    denom = max(n_seen, 1)
    return loss_total / denom, n_correct / denom


# =========================================================
# Train
# =========================================================
# Train from start_epoch (1, or the epoch after a restored checkpoint) up to
# num_epochs; validate and overwrite the checkpoint after every epoch.
for epoch in range(start_epoch, num_epochs + 1):
    model.train()
    running_loss, running_correct, running_total = 0.0, 0, 0

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        spk_in = ttfs_encode_ucf_equalized_latency(
            x, T_steps=T_steps, bin_edges=bin_edges, cdf=cdf,
            agg=cdf_agg, normalize=normalize_lat, linear=linear_lat,
            tau=tau_lat, threshold=threshold_lat
        )

        logits_rec = model(spk_in)
        logits = logits_rec.mean(dim=0)  # average the per-step logits over time
        loss = F.cross_entropy(logits, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # weight the batch loss by its size so the epoch average is per sample
        running_loss += loss.item() * y.size(0)
        running_correct += (logits.argmax(dim=1) == y).sum().item()
        running_total += y.size(0)

    scheduler.step()
    val_loss, val_acc = evaluate(val_loader)  # note: evaluate() switches the model to eval mode

    print(f"Epoch {epoch:02d} | "
          f"train loss {running_loss/running_total:.4f} | "
          f"train acc {running_correct/running_total:.4f} | "
          f"val loss {val_loss:.4f} | "
          f"val acc {val_acc:.4f}")

    # The same file is overwritten every epoch; the encoder settings are
    # stored alongside the weights for reference.
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "class_to_idx": class_to_idx,
        "T_steps": T_steps,
        "cdf_bins": cdf_bins,
        "cdf_agg": cdf_agg,
        "tau_lat": tau_lat,
        "threshold_lat": threshold_lat,
    }, ckpt_path)
    print(f"Saved checkpoint: {ckpt_path}")

# Final held-out score, computed once after training.
test_loss, test_acc = evaluate(test_loader)
print(f"Test loss {test_loss:.4f} | Test acc {test_acc:.4f}")


num_classes = 10  | T_steps = 16  | train batches = 37
CDF built.
Resumed from epoch 31
Epoch 31 | train loss 1.5144 | train acc 0.3986 | val loss 1.4767 | val acc 0.4600
Saved checkpoint: ./ucf_subset_ttfs_conv3d_snn_ckpt.pth
Epoch 32 | train loss 1.5405 | train acc 0.4189 | val loss 1.4818 | val acc 0.4100
Saved checkpoint: ./ucf_subset_ttfs_conv3d_snn_ckpt.pth
Epoch 33 | train loss 1.5395 | train acc 0.4122 | val loss 1.4781 | val acc 0.4500
Saved checkpoint: ./ucf_subset_ttfs_conv3d_snn_ckpt.pth
Epoch 34 | train loss 1.5093 | train acc 0.4527 | val loss 1.4703 | val acc 0.4600
Saved checkpoint: ./ucf_subset_ttfs_conv3d_snn_ckpt.pth
Epoch 35 | train loss 1.5111 | train acc 0.4257 | val loss 1.4746 | val acc 0.4700
Saved checkpoint: ./ucf_subset_ttfs_conv3d_snn_ckpt.pth
Epoch 36 | train loss 1.5327 | train acc 0.4122 | val loss 1.4615 | val acc 0.4800
Saved checkpoint: ./ucf_subset_ttfs_conv3d_snn_ckpt.pth
Epoch 37 | train loss 1.5288 | train acc 0.4527 | val loss 1.4710 | val acc 0.

4. ISI Encoding

In [None]:
import os
import cv2
import numpy as np
import random
import pathlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import snntorch as snn

# =========================================================
# Config
# =========================================================
device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is visible

batch_size = 8
num_epochs = 80
lr = 2e-3

# video sampling
T_steps    = 16     # frames_per_clip (also encoder time steps)
frame_step = 2
H, W       = 112, 112  # frame size; read as globals by UCFSubsetDataset.__getitem__

# ISI encoder
# NOTE(review): presumably passed to isi_fixedK_no_endcaps_strict as K / alpha_max / eps;
# the call site is not visible in this part of the notebook — confirm.
K_spikes  = 2
alpha_max = 2.0
eps = 1e-3
agg_mode = "max"     # "max" or "mean" to collapse clip -> image before ISI

# LIF
tau = 2.0
beta = torch.exp(torch.tensor(-1.0 / tau)).item()  # LIF decay factor, beta = exp(-1/tau)

# checkpoint paths
last_ckpt_path = "./ucf_subset_isi_last.pth"
best_ckpt_path = "./ucf_subset_isi_best.pth"

torch.backends.cudnn.benchmark = True


# =========================================================
# Video utils
# =========================================================
def resize_with_pad(frame_bgr, out_hw=(112,112)):
    """Letterbox a frame: scale it to fit ``out_hw`` and centre it on a black canvas.

    Args:
        frame_bgr: image array whose first two axes are (height, width).
        out_hw: (height, width) of the returned canvas.

    Returns:
        uint8 array of shape (out_h, out_w, 3). The aspect ratio of the input
        is preserved; uncovered borders are zero. An all-zero canvas is
        returned when the input has zero height or width.
    """
    out_h, out_w = out_hw
    h, w = frame_bgr.shape[:2]
    if h == 0 or w == 0:
        return np.zeros((out_h, out_w, 3), dtype=np.uint8)

    scale = min(out_w / w, out_h / h)
    # int() truncates, so a very elongated frame could end up with a side of
    # 0 pixels, which cv2.resize rejects. Keep at least one pixel per side.
    nw = max(1, int(w * scale))
    nh = max(1, int(h * scale))
    resized = cv2.resize(frame_bgr, (nw, nh), interpolation=cv2.INTER_AREA)

    canvas = np.zeros((out_h, out_w, 3), dtype=np.uint8)
    top = (out_h - nh) // 2
    left = (out_w - nw) // 2
    canvas[top:top+nh, left:left+nw] = resized
    return canvas

def frames_from_video(video_path, n_frames, out_hw=(112,112), frame_step=2, training=True):
    """Read a fixed-length clip from a video file.

    One frame is taken at the start position, then every ``frame_step``-th
    frame after it, and each kept frame is letterboxed with
    ``resize_with_pad``.

    Args:
        video_path: path of the video; converted with ``str`` for OpenCV.
        n_frames: number of frames in the clip.
        out_hw: (height, width) of every returned frame.
        frame_step: number of ``cap.read()`` calls between two kept frames.
        training: if True, the start frame is drawn with ``random.randint``
            whenever the reported frame count is long enough; otherwise the
            clip starts at frame 0.

    Returns:
        uint8 RGB array [T,H,W,3]. The clip is all zeros when the file cannot
        be opened or its first frame cannot be read; frames that cannot be
        read later are replaced by zero frames.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return np.zeros((n_frames, out_hw[0], out_hw[1], 3), dtype=np.uint8)

        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Number of source frames spanned by the clip.
        need_len = 1 + (n_frames - 1) * frame_step
        if length <= 0 or need_len > length:
            start = 0
        else:
            max_start = length - need_len
            start = random.randint(0, max_start) if training else 0

        cap.set(cv2.CAP_PROP_POS_FRAMES, start)

        ret, frame = cap.read()
        if not ret:
            return np.zeros((n_frames, out_hw[0], out_hw[1], 3), dtype=np.uint8)

        frames = [resize_with_pad(frame, out_hw)]

        for _ in range(n_frames - 1):
            # Advance frame_step frames; only the last one read is kept.
            for _ in range(frame_step):
                ret, frame = cap.read()
            if ret:
                frames.append(resize_with_pad(frame, out_hw))
            else:
                frames.append(np.zeros_like(frames[0]))
    finally:
        # Release the capture handle on every path: unopened file, early
        # return, or an exception raised while a frame is being processed.
        cap.release()

    # BGR -> RGB and make contiguous (avoid negative strides)
    frames = np.stack(frames, axis=0)[..., ::-1].copy()
    return frames  # [T,H,W,3] uint8 RGB


# =========================================================
# Dataset
# =========================================================
IM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IM_STD  = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

class UCFSubsetDataset(Dataset):
    """Video-clip dataset over a ``<split_dir>/<class_name>/*.avi`` folder layout.

    Each item is a normalized clip tensor ``[T, 3, H, W]`` and its integer
    label. The spatial size comes from the module-level ``H`` and ``W``.

    Args:
        split_dir: root folder of one split (train / val / test).
        n_frames: frames per clip.
        frame_step: frame stride forwarded to ``frames_from_video``.
        training: forwarded to ``frames_from_video`` (random clip start).
        class_to_idx: optional class-name -> index mapping; when omitted it is
            built from the sorted class folder names found in ``split_dir``.
    """

    def __init__(self, split_dir, n_frames=16, frame_step=2, training=False, class_to_idx=None):
        self.split_dir = pathlib.Path(split_dir)
        self.n_frames = n_frames
        self.frame_step = frame_step
        self.training = training

        # Videos live exactly one level below the split root; the folder is the class.
        self.video_paths = sorted(self.split_dir.glob("*/*.avi"))
        self.class_names = sorted({path.parent.name for path in self.video_paths})

        # Reuse the caller's mapping (same object) so splits share label ids.
        self.class_to_idx = (
            {name: index for index, name in enumerate(self.class_names)}
            if class_to_idx is None
            else class_to_idx
        )

        self.labels = [self.class_to_idx[path.parent.name] for path in self.video_paths]

    def __len__(self):
        """Number of videos found in the split."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Return ``(clip, label)``: clip is float ``[T,3,H,W]``, label a long scalar tensor."""
        clip_path, label = self.video_paths[idx], self.labels[idx]

        clip_np = frames_from_video(clip_path, self.n_frames, (H,W), self.frame_step, self.training)
        # [T,H,W,3] uint8 -> [T,3,H,W] float in [0,1]
        clip = torch.from_numpy(clip_np).permute(0,3,1,2).float() / 255.0
        # ImageNet normalize
        clip = (clip - IM_MEAN) / IM_STD
        return clip, torch.tensor(label, dtype=torch.long)

def collate_btchw(batch):
    """Stack ``(clip, label)`` samples into batched tensors.

    Returns:
        ``(x, y)`` where ``x`` stacks the clips along a new leading batch axis
        ([B,T,3,H,W] for [T,3,H,W] clips) and ``y`` stacks the labels.
    """
    clips, labels = zip(*batch)
    return torch.stack(clips, dim=0), torch.stack(labels, dim=0)


# =========================================================
# ISI fixed-K strict encoder (same spirit as your snippet)
# =========================================================
@torch.no_grad()
def isi_fixedK_no_endcaps_strict(
    x_img_unit: torch.Tensor, T: int, K: int, alpha_max: float = 2.0, eps: float = 1e-3
) -> torch.Tensor:
    """Encode each input value as K spike times on a T-step axis (fixed spike count).

    For every element, an exponential weight profile over the T time bins is
    built from its intensity, normalized, and accumulated into a CDF. K
    quantiles of that CDF give K candidate spike times; candidates that collide
    on the same bin are moved to a free bin so the K times of one element are
    distinct.

    Args:
        x_img_unit: tensor whose first axis is the batch; values are clamped to
            [0, 1]. Must be viewable as ``[B, -1]`` (``.view`` is used).
        T: number of time steps (>= 2).
        K: spikes per element (1 <= K <= T).
        alpha_max: maximum magnitude of the exponential slope.
        eps: margin keeping the quantiles inside ``[eps, 1 - eps]``.

    Returns:
        float32 tensor of shape ``[T, B, *x_img_unit.shape[1:]]`` holding 1.0 at
        the selected spike times and 0.0 elsewhere.

    Raises:
        AssertionError: if ``T < 2`` or ``K < 1``.
        ValueError: if ``K > T``.
    """
    assert T >= 2 and K >= 1
    if K > T:
        raise ValueError(f"K={K} must satisfy K<=T={T}.")

    device = x_img_unit.device
    B = x_img_unit.size(0)
    x = x_img_unit.view(B, -1).clamp(0.0, 1.0)  # [B,N]
    N = x.size(1)

    M = T
    j = torch.arange(M, device=device, dtype=torch.float32).view(1, 1, M)
    mid = (M - 1) / 2.0

    # Slope in [-alpha_max, alpha_max]: negative below 0.5, positive above.
    alpha = (x * 2.0 - 1.0) * alpha_max
    alpha = alpha.unsqueeze(-1)  # [B,N,1]

    # Exponential weights over time bins -> normalized -> cumulative distribution.
    w = torch.exp(alpha * (j - mid))
    w = w / (w.sum(dim=-1, keepdim=True) + 1e-12)
    c = torch.cumsum(w, dim=-1)  # [B,N,T]

    # K evenly spaced quantile levels, shared by every element.
    q = torch.linspace(eps, 1.0 - eps, steps=K, device=device, dtype=torch.float32)
    q = q.view(1, 1, K).expand(B, N, K)

    # Candidate spike time per quantile, sorted ascending per element.
    t_idx = torch.searchsorted(c, q).clamp(0, T - 1).long()
    t_idx, _ = torch.sort(t_idx, dim=-1)

    used = torch.zeros(B, N, T, device=device, dtype=torch.bool)  # bins already taken
    t_fixed = torch.full_like(t_idx, -1)                          # -1 = still unresolved

    # Pass 1: a candidate keeps its bin if the bin is still free; duplicates stay at -1.
    for k in range(K):
        tk = t_idx[..., k]
        free = ~used.gather(dim=2, index=tk.unsqueeze(-1)).squeeze(-1)
        t_fixed[..., k] = torch.where(free, tk, torch.full_like(tk, -1))
        if free.any():
            used[free] |= F.one_hot(tk[free], num_classes=T).bool()

    # Pass 2: move every unresolved candidate to a free bin, preferring the first
    # free bin at/after its target and otherwise the last free bin at/before it.
    for k in range(K):
        need = (t_fixed[..., k] < 0)
        if not need.any():
            continue

        tk = t_idx[..., k].clone()
        avail = ~used
        ar = torch.arange(T, device=device).view(1, 1, T)

        # argmax over a 0/1 mask yields the first True position.
        forward_mask = avail & (ar >= tk.unsqueeze(-1))
        fwd_pos = forward_mask.float().argmax(dim=-1)
        fwd_exists = forward_mask.any(dim=-1)

        # Flip the time axis so argmax finds the last True position instead.
        backward_mask = avail & (ar <= tk.unsqueeze(-1))
        rev = torch.flip(backward_mask, dims=[-1])
        bwd_pos_rev = rev.float().argmax(dim=-1)
        bwd_pos = (T - 1) - bwd_pos_rev
        bwd_exists = backward_mask.any(dim=-1)

        chosen = torch.where(fwd_exists, fwd_pos, bwd_pos).long()
        # Fallback to bin 0 when no free bin exists in either direction.
        chosen = torch.where(fwd_exists | bwd_exists, chosen, torch.zeros_like(chosen))

        # Only elements flagged in `need` are updated and marked as used.
        t_fixed[..., k] = torch.where(need, chosen, t_fixed[..., k])
        used[need] |= F.one_hot(chosen[need], num_classes=T).bool()

    # Scatter the K resolved times of every element into a dense spike tensor.
    spk_flat = torch.zeros(T, B, N, device=device, dtype=torch.float32)
    b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, N, K)
    n_idx = torch.arange(N, device=device).view(1, N, 1).expand(B, N, K)
    spk_flat[t_fixed, b_idx, n_idx] = 1.0

    return spk_flat.view(T, B, *x_img_unit.shape[1:])  # [T,B,3,H,W]

@torch.no_grad()
def isi_encode_ucf_video(
    x_video_norm: torch.Tensor,  # [B,T,3,H,W] normalized
    T_steps: int,
    K: int,
    alpha_max: float = 2.0,
    eps: float = 1e-3,
    agg="max",
):
    """Collapse a normalized clip to one image and ISI-encode it into spikes.

    The clip is de-normalized with ``IM_MEAN`` / ``IM_STD``, clamped to [0, 1],
    reduced over the frame axis (dim 1) and passed to
    ``isi_fixedK_no_endcaps_strict``.

    Args:
        x_video_norm: normalized clip batch, ``[B,T,3,H,W]``.
        T_steps: encoder time steps (forwarded as ``T``).
        K: spikes per element.
        alpha_max: forwarded to the encoder.
        eps: forwarded to the encoder.
        agg: ``"max"`` or ``"mean"`` reduction over frames.

    Returns:
        Whatever ``isi_fixedK_no_endcaps_strict`` returns for the collapsed image.

    Raises:
        ValueError: if ``agg`` is neither ``"max"`` nor ``"mean"``.
    """
    dev = x_video_norm.device
    unit_clip = (x_video_norm * IM_STD.to(dev) + IM_MEAN.to(dev)).clamp(0.0, 1.0)

    if agg not in ("max", "mean"):
        raise ValueError("agg must be 'max' or 'mean'")
    # Reduce over the frame axis -> one image per sample.
    still = unit_clip.max(dim=1).values if agg == "max" else unit_clip.mean(dim=1)

    return isi_fixedK_no_endcaps_strict(still, T=T_steps, K=K, alpha_max=alpha_max, eps=eps)


# =========================================================
# Conv3D SNN (analog logits head)
# =========================================================
class Conv3DSNN_ISI(nn.Module):
    """Conv3d spiking network with an analog (non-spiking) logits readout.

    Three Conv3d -> BatchNorm3d -> LIF -> spatial max-pool stages are followed
    by global spatial average pooling and, per time step, Linear -> LIF ->
    Linear producing class logits.

    Args:
        beta: membrane decay passed to every ``snn.Leaky`` layer.
        num_classes: size of the logits vector.
    """

    def __init__(self, beta=0.95, num_classes=10):
        super().__init__()
        # Stage 1: 3 -> 32 channels
        self.conv1 = nn.Conv3d(3, 32, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(32)
        self.lif1 = snn.Leaky(beta=beta)

        # Stage 2: 32 -> 64 channels
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(64)
        self.lif2 = snn.Leaky(beta=beta)

        # Stage 3: 64 -> 128 channels
        self.conv3 = nn.Conv3d(64, 128, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm3d(128)
        self.lif3 = snn.Leaky(beta=beta)

        # Pool over H and W only; the time axis keeps its length.
        self.pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.gap = nn.AdaptiveAvgPool3d(output_size=(None, 1, 1))

        # Readout head applied independently at every time step.
        self.fc1 = nn.Linear(in_features=128, out_features=256)
        self.lif4 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(in_features=256, out_features=num_classes)

    def _lif_over_time_2d(self, x_bcthw, lif, mem):
        """Run ``lif`` step by step along dim 2 of a ``[B,C,T,H,W]`` tensor.

        Returns the spikes stacked back along dim 2 and the final membrane state.
        """
        _, _, n_steps, _, _ = x_bcthw.shape  # also enforces a 5-D input
        spikes = []
        for step in range(n_steps):
            spike_t, mem = lif(x_bcthw[:, :, step], mem)
            spikes.append(spike_t)
        return torch.stack(spikes, dim=2), mem

    def forward(self, spk_in):
        """Map input spikes ``[T,B,3,H,W]`` to per-step logits ``[T,B,num_classes]``."""
        feats = spk_in.permute(1,2,0,3,4).contiguous()  # [B,3,T,H,W]

        # Fresh membrane state for every LIF layer on each forward pass.
        lif_layers = (self.lif1, self.lif2, self.lif3, self.lif4)
        mems = [lif.init_leaky() for lif in lif_layers]

        conv_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for (conv, bn), lif, mem in zip(conv_stages, lif_layers, mems):
            feats, _ = self._lif_over_time_2d(bn(conv(feats)), lif, mem)
            feats = self.pool(feats)

        pooled = self.gap(feats).squeeze(-1).squeeze(-1)  # [B,128,T]

        head_mem = mems[3]
        logits_per_step = []
        for step in range(pooled.shape[2]):
            hidden_spk, head_mem = self.lif4(self.fc1(pooled[:, :, step]), head_mem)
            logits_per_step.append(self.fc2(hidden_spk))

        return torch.stack(logits_per_step, dim=0)  # [T,B,num_classes]


# =========================================================
# Build loaders from subset_paths (must exist)
# =========================================================
# The training split defines the class -> index mapping; val/test reuse it so
# label ids agree across splits.
train_ds = UCFSubsetDataset(subset_paths["train"], n_frames=T_steps, frame_step=frame_step, training=True)
class_to_idx = train_ds.class_to_idx
val_ds, test_ds = (
    UCFSubsetDataset(
        subset_paths[split_name],
        n_frames=T_steps,
        frame_step=frame_step,
        training=False,
        class_to_idx=class_to_idx,
    )
    for split_name in ("val", "test")
)

num_classes = len(class_to_idx)

# Options shared by all three loaders.
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_btchw,
)
# Only the training loader shuffles and drops the incomplete last batch.
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader   = DataLoader(val_ds, shuffle=False, drop_last=False, **_loader_kwargs)
test_loader  = DataLoader(test_ds, shuffle=False, drop_last=False, **_loader_kwargs)

print(f"num_classes = {num_classes} | T_steps = {T_steps} | train batches = {len(train_loader)}")


# =========================================================
# Train / Eval + BEST checkpoint
# =========================================================
model = Conv3DSNN_ISI(beta=beta, num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Defaults for a fresh run; overwritten below when a LAST checkpoint exists.
start_epoch, best_val = 1, -1.0

# optional resume from LAST
if os.path.exists(last_ckpt_path):
    ckpt = torch.load(last_ckpt_path, map_location=device)
    for _target, _state_key in (
        (model, "model_state"),
        (optimizer, "optimizer_state"),
        (scheduler, "scheduler_state"),
    ):
        _target.load_state_dict(ckpt[_state_key])
    # The checkpoint stores the last finished epoch, so continue with the next one.
    start_epoch = ckpt["epoch"] + 1
    best_val = ckpt.get("best_val", -1.0)
    print("Resumed from epoch", start_epoch, "best_val", best_val)

@torch.no_grad()
def evaluate(loader):
    """Evaluate the module-level ``model`` on ``loader``.

    Each batch is ISI-encoded with the module-level encoder settings, the
    per-step logits are averaged over time and scored with cross-entropy.
    The model is left in eval mode.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples; ``(0.0, 0.0)`` for
        an empty loader.
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    summed_loss = 0.0
    for clips, targets in loader:
        clips = clips.to(device, non_blocking=True)  # [B,T,3,H,W]
        targets = targets.to(device, non_blocking=True)

        spikes = isi_encode_ucf_video(
            clips, T_steps=T_steps, K=K_spikes, alpha_max=alpha_max, eps=eps, agg=agg_mode
        )  # [T,B,3,H,W]

        # [T,B,C] -> time-averaged logits [B,C]
        mean_logits = model(spikes).mean(dim=0)
        batch_loss = F.cross_entropy(mean_logits, targets)

        n_correct += (mean_logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.numel()
        # Weight the batch-mean loss by batch size to get a per-sample average.
        summed_loss += batch_loss.item() * targets.size(0)

    denom = max(n_seen, 1)
    return summed_loss / denom, n_correct / denom

# Epoch loop: train, validate, then write the LAST (resume) checkpoint and,
# when validation accuracy improves, the BEST checkpoint.
for epoch in range(start_epoch, num_epochs + 1):
    model.train()
    running_loss, running_correct, running_total = 0.0, 0, 0

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        spk_in = isi_encode_ucf_video(
            x, T_steps=T_steps, K=K_spikes, alpha_max=alpha_max, eps=eps, agg=agg_mode
        )

        logits_rec = model(spk_in)
        logits = logits_rec.mean(dim=0)  # average per-step logits over time
        loss = F.cross_entropy(logits, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item() * y.size(0)
        running_correct += (logits.argmax(dim=1) == y).sum().item()
        running_total += y.size(0)

    scheduler.step()
    val_loss, val_acc = evaluate(val_loader)

    # Guard the log line against an empty training loader (drop_last=True can
    # leave zero batches on a very small subset).
    train_seen = max(running_total, 1)
    print(f"Epoch {epoch:02d} | "
          f"train loss {running_loss/train_seen:.4f} | "
          f"train acc {running_correct/train_seen:.4f} | "
          f"val loss {val_loss:.4f} | "
          f"val acc {val_acc:.4f}")

    # Update best_val BEFORE saving LAST. Otherwise a resume restores a stale
    # best_val and a later, worse epoch could overwrite the BEST checkpoint.
    is_best = val_acc > best_val
    if is_best:
        best_val = val_acc

    ckpt_state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_val": best_val,
        "class_to_idx": class_to_idx,
        "T_steps": T_steps,
        "K_spikes": K_spikes,
        "alpha_max": alpha_max,
        "agg_mode": agg_mode,
    }

    # -------- save LAST --------
    torch.save(ckpt_state, last_ckpt_path)

    # -------- save BEST (by val acc) --------
    if is_best:
        torch.save(ckpt_state, best_ckpt_path)
        print(f"Saved BEST: epoch {epoch} | best_val={best_val:.4f} -> {best_ckpt_path}")


# =========================================================
# Test using BEST checkpoint (important)
# =========================================================
# Final test-set evaluation with the weights of the best validation epoch
# rather than the weights left in memory by the last epoch.
print("\nLoading BEST checkpoint for test...")
# NOTE(review): fails if no BEST checkpoint was ever written (e.g. the epoch loop did not run).
best = torch.load(best_ckpt_path, map_location=device)
model.load_state_dict(best["model_state"])

test_loss, test_acc = evaluate(test_loader)
print(f"BEST epoch {best['epoch']} | best_val {best['best_val']:.4f} | Test loss {test_loss:.4f} | Test acc {test_acc:.4f}")


num_classes = 10 | T_steps = 16 | train batches = 37
Epoch 01 | train loss 2.2887 | train acc 0.0946 | val loss 2.1836 | val acc 0.2200
Saved BEST: epoch 1 | best_val=0.2200 -> ./ucf_subset_isi_best.pth
Epoch 02 | train loss 2.1531 | train acc 0.1689 | val loss 2.0812 | val acc 0.1800
Epoch 03 | train loss 2.1422 | train acc 0.1824 | val loss 1.9947 | val acc 0.2400
Saved BEST: epoch 3 | best_val=0.2400 -> ./ucf_subset_isi_best.pth
Epoch 04 | train loss 2.0531 | train acc 0.2061 | val loss 1.9709 | val acc 0.2300
Epoch 05 | train loss 2.0429 | train acc 0.2095 | val loss 2.0509 | val acc 0.1700
Epoch 06 | train loss 2.0075 | train acc 0.2264 | val loss 1.8715 | val acc 0.2800
Saved BEST: epoch 6 | best_val=0.2800 -> ./ucf_subset_isi_best.pth
Epoch 07 | train loss 1.9223 | train acc 0.2635 | val loss 1.9380 | val acc 0.2200
Epoch 08 | train loss 1.9170 | train acc 0.2703 | val loss 2.3416 | val acc 0.1700
Epoch 09 | train loss 1.9060 | train acc 0.2601 | val loss 2.3667 | val acc 0.1800

4. TTFS-Phase

In [None]:
import os
import cv2
import numpy as np
import random
import pathlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import snntorch as snn

# =========================================================
# Config
# =========================================================
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 8
num_epochs = 80
lr = 2e-3

# clip sampling
T_FRAMES   = 16       # frames_per_clip
frame_step = 2
H, W       = 112, 112

# encoder time axis
T_STEPS = 60
P = 3
phi0 = 0
assert 0 <= phi0 < P
# Number of phase-locked time slots phi0, phi0+P, ... that fit in [0, T_STEPS-1].
M = int(((T_STEPS - 1 - phi0) // P) + 1)  # maxima bins

# "video-aware" spike budget
K_SPIKES = 4  # <=4 spikes per pixel (one per frame-group)

# ranking jitter
JITTER = 1e-4   # for stability on many ties

# LIF
tau = 2.0
beta = torch.exp(torch.tensor(-1.0 / tau)).item()

# ckpt
last_ckpt_path = "./ucf_subset_ttfsphase_videoK4_last.pth"
best_ckpt_path = "./ucf_subset_ttfsphase_videoK4_best.pth"

# Fresh start: delete checkpoints left by a previous run so the resume branch
# further down does not pick them up. Driven by the configured paths (rather
# than duplicated string literals) so the two can never drift apart.
for p in (last_ckpt_path, best_ckpt_path):
    if os.path.exists(p):
        os.remove(p)
        print("Removed:", p)

torch.backends.cudnn.benchmark = True


# =========================================================
# Video utils
# =========================================================
def resize_with_pad(frame_bgr, out_hw=(112,112)):
    """Letterbox a frame onto a fixed-size black canvas, preserving aspect ratio.

    The frame is scaled by the largest factor that keeps it inside ``out_hw``
    and is then centred on a zero-filled canvas.

    Args:
        frame_bgr: image array whose first two axes are (height, width). It is
            written into a 3-channel canvas, so a 3-channel frame is expected.
        out_hw: target ``(height, width)`` of the canvas.

    Returns:
        ``uint8`` array of shape ``(out_h, out_w, 3)``. An all-zero canvas is
        returned when the input has zero height or zero width.
    """
    out_h, out_w = out_hw
    h, w = frame_bgr.shape[:2]
    canvas = np.zeros((out_h, out_w, 3), dtype=np.uint8)
    if h == 0 or w == 0:
        return canvas

    scale = min(out_w / w, out_h / h)
    # int() truncation can reach 0 for extremely thin frames and cv2.resize
    # rejects an empty target size, so keep each side within [1, canvas side].
    nw = min(out_w, max(1, int(w * scale)))
    nh = min(out_h, max(1, int(h * scale)))
    resized = cv2.resize(frame_bgr, (nw, nh), interpolation=cv2.INTER_AREA)

    # Centre the resized frame; the remaining border stays black.
    top = (out_h - nh) // 2
    left = (out_w - nw) // 2
    canvas[top:top+nh, left:left+nw] = resized
    return canvas

def frames_from_video(video_path, n_frames, out_hw=(112,112), frame_step=2, training=True):
    """Read a fixed-length clip of letterboxed RGB frames from a video file.

    Args:
        video_path: path of the video; converted with ``str()`` for OpenCV.
        n_frames: number of frames in the returned clip.
        out_hw: ``(height, width)`` every frame is letterboxed to.
        frame_step: number of reads performed between two kept frames.
        training: when True and the video is long enough, the clip starts at a
            random frame; otherwise it starts at frame 0.

    Returns:
        ``uint8`` array of shape ``[n_frames, H, W, 3]`` in RGB order. An
        all-zero clip is returned when the video cannot be opened or its first
        frame cannot be read; frames that cannot be read later are zero-filled.
    """
    blank_clip_shape = (n_frames, out_hw[0], out_hw[1], 3)
    cap = cv2.VideoCapture(str(video_path))
    # try/finally releases the capture on every exit path, including exceptions
    # raised while reading or resizing a frame.
    try:
        if not cap.isOpened():
            return np.zeros(blank_clip_shape, dtype=np.uint8)

        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        need_len = 1 + (n_frames - 1) * frame_step
        if length <= 0 or need_len > length:
            # Unknown or too-short video: start at 0 and let the tail be zero-padded.
            start = 0
        else:
            max_start = length - need_len
            start = random.randint(0, max_start) if training else 0

        cap.set(cv2.CAP_PROP_POS_FRAMES, start)

        ret, frame = cap.read()
        if not ret:
            return np.zeros(blank_clip_shape, dtype=np.uint8)

        frames = [resize_with_pad(frame, out_hw)]
        for _ in range(n_frames - 1):
            # Only the last of the `frame_step` reads is kept.
            for _ in range(frame_step):
                ret, frame = cap.read()
            if ret:
                frames.append(resize_with_pad(frame, out_hw))
            else:
                frames.append(np.zeros_like(frames[0]))
    finally:
        cap.release()

    frames = np.stack(frames, axis=0)[..., ::-1].copy()  # BGR->RGB contiguous
    return frames  # [T,H,W,3] uint8


# =========================================================
# Dataset
# =========================================================
# Per-channel RGB mean/std used to normalize clips in the dataset and to undo
# that normalization before spike encoding.
# Shaped (1, 3, 1, 1) so they broadcast over a [Tf, 3, H, W] clip.
IM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IM_STD  = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

class UCFSubsetDataset(Dataset):
    """Video-clip dataset over a ``<split_dir>/<class_name>/*.avi`` folder layout.

    Each item is a normalized clip tensor ``[Tf, 3, H, W]`` and its integer
    label. The spatial size comes from the module-level ``H`` and ``W``.

    Args:
        split_dir: root folder of one split (train / val / test).
        n_frames: frames per clip.
        frame_step: frame stride forwarded to ``frames_from_video``.
        training: forwarded to ``frames_from_video`` (random clip start).
        class_to_idx: optional class-name -> index mapping; when omitted it is
            built from the sorted class folder names found in ``split_dir``.
    """

    def __init__(self, split_dir, n_frames=16, frame_step=2, training=False, class_to_idx=None):
        self.split_dir = pathlib.Path(split_dir)
        self.n_frames = n_frames
        self.frame_step = frame_step
        self.training = training

        # Videos live exactly one level below the split root; the folder is the class.
        self.video_paths = sorted(self.split_dir.glob("*/*.avi"))
        self.class_names = sorted({path.parent.name for path in self.video_paths})

        # Reuse the caller's mapping (same object) so splits share label ids.
        self.class_to_idx = (
            {name: index for index, name in enumerate(self.class_names)}
            if class_to_idx is None
            else class_to_idx
        )

        self.labels = [self.class_to_idx[path.parent.name] for path in self.video_paths]

    def __len__(self):
        """Number of videos found in the split."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Return ``(clip, label)``: clip is float ``[Tf,3,H,W]``, label a long scalar tensor."""
        clip_path, label = self.video_paths[idx], self.labels[idx]

        clip_np = frames_from_video(clip_path, self.n_frames, (H,W), self.frame_step, self.training)
        # [Tf,H,W,3] uint8 -> [Tf,3,H,W] float in [0,1]
        clip = torch.from_numpy(clip_np).permute(0,3,1,2).float() / 255.0
        clip = (clip - IM_MEAN) / IM_STD
        return clip, torch.tensor(label, dtype=torch.long)

def collate_btchw(batch):
    """Stack ``(clip, label)`` samples into batched tensors.

    Returns:
        ``(x, y)`` where ``x`` stacks the clips along a new leading batch axis
        ([B,Tf,3,H,W] for [Tf,3,H,W] clips) and ``y`` stacks the labels.
    """
    clips, labels = zip(*batch)
    return torch.stack(clips, dim=0), torch.stack(labels, dim=0)


# =========================================================
# Helper: normalized -> unit [0,1]
# =========================================================
@torch.no_grad()
def video_to_unit_interval(x_video_norm: torch.Tensor) -> torch.Tensor:
    """Undo the ``IM_MEAN`` / ``IM_STD`` normalization and clamp the result to [0, 1].

    The statistics are moved to the input's device before use.
    """
    dev = x_video_norm.device
    restored = x_video_norm * IM_STD.to(dev) + IM_MEAN.to(dev)
    return restored.clamp(0.0, 1.0)


# =========================================================
# Video-aware TTFS-Phase (Method B), K=4 spikes/pixel EXACT
# Group-reduce scores within each frame-group, then rank-balance once per group
# =========================================================
@torch.no_grad()
def ttfs_phase_rank_balance_per_frame_k4(
    x_video_norm: torch.Tensor,  # [B,Tf,3,H,W] normalized
    T: int,
    P: int,
    phi0: int,
    K: int = 4,
    jitter: float = 1e-4,
    group_reduce: str = "mean",  # "mean" or "max"
) -> torch.Tensor:
    """
    Video-aware TTFS-Phase (Method B) with a spike budget of K per pixel:
      - split the Tf frames into K consecutive groups (Tf must be divisible by K)
      - within each group, reduce scores across frames (mean/max) -> one score per pixel
      - rank-balance ONCE per group -> one phase-locked spike per pixel per group

    Properties:
      - STRICT phase-lock: every spike time is t = phi0 + k*P with k in [0, M-1],
        M = (T - 1 - phi0) // P + 1
      - AT MOST K spikes per pixel, binary: each group emits one spike per pixel,
        but spikes of different groups that land on the same time step merge
        into a single 1.0
      - with jitter > 0 the ranking uses random noise, so the output is not
        deterministic unless the torch RNG is seeded

    Args:
      x_video_norm: normalized clip batch [B,Tf,C,H,W]
      T: number of encoder time steps (must exceed phi0)
      P: phase period (spacing of allowed spike times)
      phi0: phase offset, 0 <= phi0 < P
      K: number of frame-groups (any positive divisor of Tf; default 4)
      jitter: std of the Gaussian tie-breaking noise added to the scores (0 disables)
      group_reduce: "mean" or "max" reduction over the frames of a group

    Returns:
      spk: [T, B, C, H, W] float32

    Raises:
      AssertionError: on invalid K / phi0 / T, or if Tf is not divisible by K
      ValueError: if group_reduce is neither "mean" nor "max"
    """
    assert K >= 1, "K must be a positive number of frame-groups"
    assert 0 <= phi0 < P
    # Without this M would be <= 0 and the bin clamp below would produce invalid times.
    assert T > phi0, "T must be larger than phi0 so at least one phase slot exists"
    if group_reduce not in ("mean", "max"):
        raise ValueError("group_reduce must be 'mean' or 'max'")

    device = x_video_norm.device
    B, Tf, C, H_, W_ = x_video_norm.shape
    N = C * H_ * W_

    assert Tf % K == 0, "Tf must be divisible by K."
    group_size = Tf // K

    # number of maxima bins
    M = (T - 1 - phi0) // P + 1

    # map normalized -> [0,1]
    x_unit = video_to_unit_interval(x_video_norm)    # [B,Tf,3,H,W]
    x_flat = x_unit.reshape(B, Tf, -1)               # [B,Tf,N]

    spk = torch.zeros(T, B, N, device=device, dtype=torch.float32)

    b_idx = torch.arange(B, device=device).view(B, 1).expand(B, N)
    n_idx = torch.arange(N, device=device).view(1, N).expand(B, N)

    for g in range(K):
        group_frames = x_flat[:, g * group_size:(g + 1) * group_size, :]

        # group score aggregation: [B,N]
        if group_reduce == "mean":
            score = group_frames.mean(dim=1)
        else:
            score = group_frames.max(dim=1).values

        # break ties for stable ranks
        if jitter and jitter > 0:
            score = score + jitter * torch.randn_like(score)

        # rank: 0..N-1 (0=brightest)
        order = torch.argsort(score, dim=1, descending=True)  # [B,N]
        inv_rank = torch.empty_like(order)
        # n_idx holds 0..N-1 per row, i.e. exactly the rank values to scatter.
        inv_rank.scatter_(1, order, n_idx)

        # rank -> maxima bin k in [0..M-1] (uniform occupancy); integer math is
        # exact, whereas float32 loses precision once rank * M exceeds 2**24
        kbin = torch.div(inv_rank * M, N, rounding_mode="floor").clamp(0, M - 1)  # [B,N]

        # phase-lock time
        t = phi0 + kbin * P  # [B,N]

        # one spike per pixel for this group (may coincide with an earlier group's spike)
        spk[t, b_idx, n_idx] = 1.0

    return spk.view(T, B, C, H_, W_)

# =========================================================
# Conv3D SNN (analog logits head)
# =========================================================
class Conv3DSNN_TTFSPhase(nn.Module):
    """Conv3d spiking network with an analog (non-spiking) logits readout.

    Three Conv3d -> BatchNorm3d -> LIF -> spatial max-pool stages are followed
    by global spatial average pooling and, per time step, Linear -> LIF ->
    Linear producing class logits.

    Args:
        beta: membrane decay passed to every ``snn.Leaky`` layer.
        num_classes: size of the logits vector.
    """

    def __init__(self, beta=0.95, num_classes=10):
        super().__init__()
        # Stage 1: 3 -> 32 channels
        self.conv1 = nn.Conv3d(3, 32, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(32)
        self.lif1 = snn.Leaky(beta=beta)

        # Stage 2: 32 -> 64 channels
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(64)
        self.lif2 = snn.Leaky(beta=beta)

        # Stage 3: 64 -> 128 channels
        self.conv3 = nn.Conv3d(64, 128, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm3d(128)
        self.lif3 = snn.Leaky(beta=beta)

        # Pool over H and W only; the time axis keeps its length.
        self.pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.gap = nn.AdaptiveAvgPool3d(output_size=(None, 1, 1))

        # Readout head applied independently at every time step.
        self.fc1 = nn.Linear(in_features=128, out_features=256)
        self.lif4 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(in_features=256, out_features=num_classes)

    def _lif_over_time_2d(self, x_bcthw, lif, mem):
        """Run ``lif`` step by step along dim 2 of a ``[B,C,T,H,W]`` tensor.

        Returns the spikes stacked back along dim 2 and the final membrane state.
        """
        _, _, n_steps, _, _ = x_bcthw.shape  # also enforces a 5-D input
        spikes = []
        for step in range(n_steps):
            spike_t, mem = lif(x_bcthw[:, :, step], mem)
            spikes.append(spike_t)
        return torch.stack(spikes, dim=2), mem

    def forward(self, spk_in):
        """Map input spikes ``[T,B,3,H,W]`` to per-step logits ``[T,B,num_classes]``."""
        feats = spk_in.permute(1,2,0,3,4).contiguous()  # [B,3,T,H,W]

        # Fresh membrane state for every LIF layer on each forward pass.
        lif_layers = (self.lif1, self.lif2, self.lif3, self.lif4)
        mems = [lif.init_leaky() for lif in lif_layers]

        conv_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for (conv, bn), lif, mem in zip(conv_stages, lif_layers, mems):
            feats, _ = self._lif_over_time_2d(bn(conv(feats)), lif, mem)
            feats = self.pool(feats)

        pooled = self.gap(feats).squeeze(-1).squeeze(-1)  # [B,128,T]

        head_mem = mems[3]
        logits_per_step = []
        for step in range(pooled.shape[2]):
            hidden_spk, head_mem = self.lif4(self.fc1(pooled[:, :, step]), head_mem)
            logits_per_step.append(self.fc2(hidden_spk))

        return torch.stack(logits_per_step, dim=0)  # [T,B,num_classes]


# =========================================================
# Build loaders from subset_paths (must exist)
# =========================================================
# The training split defines the class -> index mapping; val/test reuse it so
# label ids agree across splits.
train_ds = UCFSubsetDataset(subset_paths["train"], n_frames=T_FRAMES, frame_step=frame_step, training=True)
class_to_idx = train_ds.class_to_idx
val_ds, test_ds = (
    UCFSubsetDataset(
        subset_paths[split_name],
        n_frames=T_FRAMES,
        frame_step=frame_step,
        training=False,
        class_to_idx=class_to_idx,
    )
    for split_name in ("val", "test")
)

num_classes = len(class_to_idx)

# Options shared by all three loaders.
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_btchw,
)
# Only the training loader shuffles and drops the incomplete last batch.
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader   = DataLoader(val_ds, shuffle=False, drop_last=False, **_loader_kwargs)
test_loader  = DataLoader(test_ds, shuffle=False, drop_last=False, **_loader_kwargs)

print(f"num_classes={num_classes} | clip_frames={T_FRAMES} | encoder_T={T_STEPS} | P={P} | M={M} | K={K_SPIKES} (<=K per pixel)")


# =========================================================
# Model + Optim + Resume + BEST
# =========================================================
model = Conv3DSNN_TTFSPhase(beta=beta, num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Defaults for a fresh run; overwritten below when a LAST checkpoint exists.
start_epoch, best_val = 1, -1.0

if os.path.exists(last_ckpt_path):
    ckpt = torch.load(last_ckpt_path, map_location=device)
    for _target, _state_key in (
        (model, "model_state"),
        (optimizer, "optimizer_state"),
        (scheduler, "scheduler_state"),
    ):
        _target.load_state_dict(ckpt[_state_key])
    # The checkpoint stores the last finished epoch, so continue with the next one.
    start_epoch = ckpt["epoch"] + 1
    best_val = ckpt.get("best_val", -1.0)
    print("Resumed from epoch", start_epoch, "best_val", best_val)

@torch.no_grad()
def evaluate(loader):
    """Evaluate the module-level ``model`` on ``loader``.

    Each batch is TTFS-Phase encoded (``group_reduce="max"``) with the
    module-level encoder settings, the per-step logits are averaged over time
    and scored with cross-entropy. The model is left in eval mode.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples; ``(0.0, 0.0)`` for
        an empty loader.
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    summed_loss = 0.0
    for clips, targets in loader:
        clips = clips.to(device, non_blocking=True)  # [B,Tf,3,H,W]
        targets = targets.to(device, non_blocking=True)

        spikes = ttfs_phase_rank_balance_per_frame_k4(
            clips, T=T_STEPS, P=P, phi0=phi0, K=K_SPIKES, jitter=JITTER, group_reduce="max"
        )  # [T,B,3,H,W]

        # [T,B,C] -> time-averaged logits [B,C]
        mean_logits = model(spikes).mean(0)
        batch_loss = F.cross_entropy(mean_logits, targets)

        n_correct += (mean_logits.argmax(1) == targets).sum().item()
        n_seen += targets.numel()
        # Weight the batch-mean loss by batch size to get a per-sample average.
        summed_loss += batch_loss.item() * targets.size(0)

    denom = max(n_seen, 1)
    return summed_loss / denom, n_correct / denom

# Epoch loop: train, validate, then write the LAST (resume) checkpoint and,
# when validation accuracy improves, the BEST checkpoint.
for epoch in range(start_epoch, num_epochs + 1):
    model.train()
    running_loss, running_correct, running_total = 0.0, 0, 0

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        # group_reduce must match the encoder call in evaluate(); training on the
        # default "mean" while evaluating on "max" feeds the model differently
        # encoded inputs at train and eval time.
        spk_in = ttfs_phase_rank_balance_per_frame_k4(
            x, T=T_STEPS, P=P, phi0=phi0, K=K_SPIKES, jitter=JITTER, group_reduce="max"
        )

        logits_rec = model(spk_in)
        logits = logits_rec.mean(0)  # average per-step logits over time
        loss = F.cross_entropy(logits, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item() * y.size(0)
        running_correct += (logits.argmax(1) == y).sum().item()
        running_total += y.size(0)

    scheduler.step()
    val_loss, val_acc = evaluate(val_loader)

    # Guard the log line against an empty training loader (drop_last=True can
    # leave zero batches on a very small subset).
    train_seen = max(running_total, 1)
    print(f"Epoch {epoch:02d} | "
          f"train loss {running_loss/train_seen:.4f} | "
          f"train acc {running_correct/train_seen:.4f} | "
          f"val loss {val_loss:.4f} | "
          f"val acc {val_acc:.4f}")

    # Update best_val BEFORE saving LAST. Otherwise a resume restores a stale
    # best_val and a later, worse epoch could overwrite the BEST checkpoint.
    is_best = val_acc > best_val
    if is_best:
        best_val = val_acc

    ckpt_state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_val": best_val,
        "class_to_idx": class_to_idx,
        "T_STEPS": T_STEPS, "P": P, "phi0": phi0,
        "K_SPIKES": K_SPIKES, "JITTER": JITTER,
    }

    # save LAST
    torch.save(ckpt_state, last_ckpt_path)

    # save BEST
    if is_best:
        torch.save(ckpt_state, best_ckpt_path)
        print(f"Saved BEST: epoch {epoch} | best_val={best_val:.4f} -> {best_ckpt_path}")

# Final test-set evaluation with the weights of the best validation epoch
# rather than the weights left in memory by the last epoch.
print("\nLoading BEST checkpoint for test...")
# NOTE(review): fails if no BEST checkpoint was ever written (e.g. the epoch loop did not run).
best = torch.load(best_ckpt_path, map_location=device)
model.load_state_dict(best["model_state"])
test_loss, test_acc = evaluate(test_loader)
print(f"BEST epoch {best['epoch']} | best_val {best['best_val']:.4f} | Test loss {test_loss:.4f} | Test acc {test_acc:.4f}")


Removed: ./ucf_subset_ttfsphase_videoK4_last.pth
Removed: ./ucf_subset_ttfsphase_videoK4_best.pth
num_classes=10 | clip_frames=16 | encoder_T=60 | P=3 | M=20 | K=4 (<=K per pixel)
Epoch 01 | train loss 2.2898 | train acc 0.1351 | val loss 2.2557 | val acc 0.1500
Saved BEST: epoch 1 | best_val=0.1500 -> ./ucf_subset_ttfsphase_videoK4_best.pth
Epoch 02 | train loss 2.2010 | train acc 0.1588 | val loss 2.1934 | val acc 0.1200
Epoch 03 | train loss 2.0534 | train acc 0.2128 | val loss 2.0252 | val acc 0.2000
Saved BEST: epoch 3 | best_val=0.2000 -> ./ucf_subset_ttfsphase_videoK4_best.pth
Epoch 04 | train loss 2.0037 | train acc 0.1926 | val loss 2.6319 | val acc 0.1100
Epoch 05 | train loss 1.9880 | train acc 0.1993 | val loss 1.9005 | val acc 0.3000
Saved BEST: epoch 5 | best_val=0.3000 -> ./ucf_subset_ttfsphase_videoK4_best.pth
Epoch 06 | train loss 1.8912 | train acc 0.2365 | val loss 1.8604 | val acc 0.3400
Saved BEST: epoch 6 | best_val=0.3400 -> ./ucf_subset_ttfsphase_videoK4_best.pt

5. ISI-Phase

In [None]:
import os
import cv2
import numpy as np
import random
import pathlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import snntorch as snn

# =========================================================
# Config
# =========================================================
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Optimisation hyper-parameters.
batch_size = 8
num_epochs = 80
lr = 2e-3

# video sampling
T_steps    = 16     # frames_per_clip (also encoder time steps)
frame_step = 2
H, W       = 112, 112

# ISI-PHASE encoder (strict timegrid)
K_spikes  = 4          # spikes per pixel/channel
P         = 2          # phase period (grid step)
phi0      = 0          # phase offset
alpha_max = 2.0
eps = 1e-3
agg_mode = "max"       # "max" or "mean" collapse clip->image before ISI-Phase

# sanity: number of allowed bins on grid
# M counts the grid times phi0, phi0+P, ... that fit inside [0, T_steps-1];
# the spike budget cannot exceed the number of available grid slots.
M = int(((T_steps - 1 - phi0) // P) + 1)
assert K_spikes <= M, f"K_spikes={K_spikes} must <= M={M}. Try smaller K or smaller P."

# LIF
# beta is the per-step membrane decay derived from the time constant: exp(-1/tau).
tau = 2.0
beta = torch.exp(torch.tensor(-1.0 / tau)).item()

# checkpoint paths
last_ckpt_path = "./ucf_subset_isiphase_last.pth"
best_ckpt_path = "./ucf_subset_isiphase_best.pth"

# Let cuDNN auto-tune convolution algorithms.
torch.backends.cudnn.benchmark = True


# =========================================================
# Video utils
# =========================================================
def resize_with_pad(frame_bgr, out_hw=(112,112)):
    """Letterbox a frame onto a fixed-size black canvas, preserving aspect ratio.

    The frame is scaled by the largest factor that keeps it inside ``out_hw``
    and is then centred on a zero-filled canvas.

    Args:
        frame_bgr: image array whose first two axes are (height, width). It is
            written into a 3-channel canvas, so a 3-channel frame is expected.
        out_hw: target ``(height, width)`` of the canvas.

    Returns:
        ``uint8`` array of shape ``(out_h, out_w, 3)``. An all-zero canvas is
        returned when the input has zero height or zero width.
    """
    out_h, out_w = out_hw
    h, w = frame_bgr.shape[:2]
    canvas = np.zeros((out_h, out_w, 3), dtype=np.uint8)
    if h == 0 or w == 0:
        return canvas

    scale = min(out_w / w, out_h / h)
    # int() truncation can reach 0 for extremely thin frames and cv2.resize
    # rejects an empty target size, so keep each side within [1, canvas side].
    nw = min(out_w, max(1, int(w * scale)))
    nh = min(out_h, max(1, int(h * scale)))
    resized = cv2.resize(frame_bgr, (nw, nh), interpolation=cv2.INTER_AREA)

    # Centre the resized frame; the remaining border stays black.
    top = (out_h - nh) // 2
    left = (out_w - nw) // 2
    canvas[top:top+nh, left:left+nw] = resized
    return canvas

def frames_from_video(video_path, n_frames, out_hw=(112,112), frame_step=2, training=True):
    """Decode a fixed-length clip from a video file.

    Args:
        video_path: path of the video (converted with ``str`` for OpenCV).
        n_frames: number of frames to return.
        out_hw: ``(height, width)`` every frame is letterboxed to.
        frame_step: number of decoder reads between two kept frames.
        training: if True and the reported frame count is long enough, the
            clip starts at a random frame; otherwise it starts at frame 0.

    Returns:
        ``uint8`` array ``[n_frames, out_h, out_w, 3]`` in RGB order. If the
        file cannot be opened or the first read fails the clip is all zeros;
        frames that cannot be read later in the clip are replaced by black
        frames.
    """
    blank_clip_shape = (n_frames, out_hw[0], out_hw[1], 3)

    cap = cv2.VideoCapture(str(video_path))
    # try/finally guarantees the capture handle is released even if decoding
    # or resizing raises part-way through.
    try:
        if not cap.isOpened():
            return np.zeros(blank_clip_shape, dtype=np.uint8)

        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Span (in source frames) covered by n_frames samples taken every frame_step.
        need_len = 1 + (n_frames - 1) * frame_step
        if length <= 0 or need_len > length:
            start = 0
        else:
            max_start = length - need_len
            start = random.randint(0, max_start) if training else 0

        cap.set(cv2.CAP_PROP_POS_FRAMES, start)

        ret, frame = cap.read()
        if not ret:
            return np.zeros(blank_clip_shape, dtype=np.uint8)

        frames = [resize_with_pad(frame, out_hw)]

        for _ in range(n_frames - 1):
            # Advance frame_step reads; only the last one is kept.
            for _ in range(frame_step):
                ret, frame = cap.read()
            if ret:
                frames.append(resize_with_pad(frame, out_hw))
            else:
                # Past the end of the stream: pad with a black frame.
                frames.append(np.zeros_like(frames[0]))
    finally:
        cap.release()

    # BGR -> RGB and make contiguous
    frames = np.stack(frames, axis=0)[..., ::-1].copy()
    return frames  # [T,H,W,3] uint8 RGB


# =========================================================
# Dataset
# =========================================================
# Per-channel mean/std used to normalize clips (the dataset's __getitem__
# labels this step "ImageNet normalize"). Shaped (1, 3, 1, 1) so they broadcast
# against [T,3,H,W] clips and, in the encoder, against [B,T,3,H,W] batches.
IM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IM_STD  = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

class UCFSubsetDataset(Dataset):
    """Clips read from a ``<split_dir>/<class_name>/*.avi`` directory tree.

    Each item is ``(clip, label)`` where ``clip`` is a normalized float tensor
    ``[T,3,H,W]`` and ``label`` is a scalar ``long`` tensor.
    """

    def __init__(self, split_dir, n_frames=16, frame_step=2, training=False, class_to_idx=None):
        """Index the videos under ``split_dir``.

        Args:
            split_dir: root folder of one split.
            n_frames: frames per clip.
            frame_step: decoder reads between kept frames.
            training: forwarded to ``frames_from_video`` (random clip start).
            class_to_idx: optional class-name -> id mapping to reuse; when
                omitted it is built from the sorted class folder names.
        """
        self.split_dir = pathlib.Path(split_dir)
        self.n_frames = n_frames
        self.frame_step = frame_step
        self.training = training

        # Sorted so item order and default label ids are stable between runs.
        self.video_paths = sorted(self.split_dir.glob("*/*.avi"))
        self.class_names = sorted(set(path.parent.name for path in self.video_paths))

        self.class_to_idx = (
            class_to_idx
            if class_to_idx is not None
            else dict(zip(self.class_names, range(len(self.class_names))))
        )

        # The parent folder name is the class of each video.
        self.labels = [self.class_to_idx[path.parent.name] for path in self.video_paths]

    def __len__(self):
        """Number of indexed videos."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Decode, scale to [0,1] and normalize the clip at position ``idx``."""
        video_file = self.video_paths[idx]
        label_id = self.labels[idx]

        clip_uint8 = frames_from_video(
            video_file, self.n_frames, (H, W), self.frame_step, self.training
        )
        # [T,H,W,3] uint8 -> [T,3,H,W] float in [0,1]
        clip = torch.from_numpy(clip_uint8).permute(0, 3, 1, 2).float() / 255.0
        clip = (clip - IM_MEAN) / IM_STD                                  # ImageNet normalize
        return clip, torch.tensor(label_id, dtype=torch.long)

def collate_btchw(batch):
    """Stack ``(clip, label)`` samples into batch tensors.

    Returns ``(clips, labels)`` where ``clips`` has a new leading batch
    dimension ([B,T,3,H,W] for [T,3,H,W] samples) and ``labels`` is ``[B]``.
    """
    clips, targets = map(list, zip(*batch))
    stacked_clips = torch.stack(clips, dim=0)      # [B,T,3,H,W]
    stacked_targets = torch.stack(targets, dim=0)
    return stacked_clips, stacked_targets


# =========================================================
# ISI-Phase Encoder (STRICT phase-lock on time grid)
# =========================================================
@torch.no_grad()
def isi_phase_fixedK_strict_timegrid(
    x_img_unit: torch.Tensor,     # [B,C,H,W] in [0,1]
    T: int,
    K: int,
    P: int,
    phi0: int = 0,
    alpha_max: float = 2.0,
    eps: float = 1e-3,
) -> torch.Tensor:
    """
    Encode an image batch as a spike train with exactly K spikes per element.

    Return spikes [T,B,C,H,W], exactly K spikes per pixel/channel.
    All spikes lie on phase-locked grid: t = phi0 + k*P.

    Args:
        x_img_unit: 4-D tensor [B,C,H,W]; values are clamped to [0,1]. May be
            non-contiguous (e.g. a permuted or channels_last tensor).
        T: number of time steps of the output.
        K: spikes emitted per pixel/channel.
        P: spacing of the time grid.
        phi0: grid offset, must satisfy 0 <= phi0 < P.
        alpha_max: largest magnitude of the exponent slope; intensity 0 maps
            to -alpha_max and intensity 1 to +alpha_max.
        eps: margin keeping the K quantile levels inside (0, 1).

    Raises:
        ValueError: if K exceeds the number of grid bins that fit in T.
    """
    assert 0 <= phi0 < P
    device = x_img_unit.device
    B, C, H, W = x_img_unit.shape

    # reshape (not view): clamp() preserves the input's strides, so .view()
    # would raise for non-contiguous inputs.
    x = x_img_unit.clamp(0.0, 1.0).reshape(B, -1)  # [B,N]
    N = x.size(1)

    M = int(((T - 1 - phi0) // P) + 1)
    if K > M:
        raise ValueError(f"K={K} must satisfy K<=M={M} (phase bins).")

    # ISI-like distribution over bins k=0..M-1
    k_grid = torch.arange(M, device=device, dtype=torch.float32).view(1, 1, M)
    mid = (M - 1) / 2.0

    # Intensity sets the slope of an exponential weighting over the bins.
    alpha = (x * 2.0 - 1.0) * alpha_max
    alpha = alpha.unsqueeze(-1)  # [B,N,1]

    w = torch.exp(alpha * (k_grid - mid))                  # [B,N,M]
    w = w / (w.sum(dim=-1, keepdim=True) + 1e-12)
    cdf = torch.cumsum(w, dim=-1).contiguous()             # [B,N,M]

    # quantiles: K evenly spaced levels mapped through the per-pixel CDF
    q = torch.linspace(eps, 1.0 - eps, steps=K, device=device, dtype=torch.float32)
    q = q.view(1, 1, K).expand(B, N, K).contiguous()

    k_idx = torch.searchsorted(cdf, q).clamp(0, M - 1).long()  # [B,N,K]
    k_idx, _ = torch.sort(k_idx, dim=-1)

    # strict uniqueness in bins: first pass keeps each proposed bin only if it
    # is still unused for that pixel; collisions are marked with -1.
    used = torch.zeros(B, N, M, device=device, dtype=torch.bool)
    k_fixed = torch.full_like(k_idx, -1)

    for kk in range(K):
        k0 = k_idx[..., kk]
        free = ~used.gather(dim=2, index=k0.unsqueeze(-1)).squeeze(-1)
        k_fixed[..., kk] = torch.where(free, k0, torch.full_like(k0, -1))
        if free.any():
            used[free] |= F.one_hot(k0[free], num_classes=M).bool()

    # resolve collisions: nearest free bin (forward then backward)
    ar = torch.arange(M, device=device).view(1, 1, M)  # loop-invariant bin index
    for kk in range(K):
        need = (k_fixed[..., kk] < 0)
        if not need.any():
            continue

        k0 = k_idx[..., kk].clone()
        avail = ~used

        # First free bin at or after the proposed one.
        forward_mask = avail & (ar >= k0.unsqueeze(-1))
        fwd_pos = forward_mask.float().argmax(dim=-1)
        fwd_exists = forward_mask.any(dim=-1)

        # Last free bin at or before the proposed one (argmax on the flipped mask).
        backward_mask = avail & (ar <= k0.unsqueeze(-1))
        rev = torch.flip(backward_mask, dims=[-1])
        bwd_pos_rev = rev.float().argmax(dim=-1)
        bwd_pos = (M - 1) - bwd_pos_rev

        chosen = torch.where(fwd_exists, fwd_pos, bwd_pos).long()
        k_fixed[..., kk] = torch.where(need, chosen, k_fixed[..., kk])
        used[need] |= F.one_hot(chosen[need], num_classes=M).bool()

    # bins -> time indices
    t_idx = (phi0 + k_fixed * P).long()  # [B,N,K], all in [0..T-1]

    spk_flat = torch.zeros(T, B, N, device=device, dtype=torch.float32)
    b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, N, K)
    n_idx = torch.arange(N, device=device).view(1, N, 1).expand(B, N, K)
    spk_flat[t_idx, b_idx, n_idx] = 1.0

    return spk_flat.view(T, B, C, H, W)


@torch.no_grad()
def isiphase_encode_ucf_video(
    x_video_norm: torch.Tensor,   # [B,T,3,H,W] normalized
    T_steps: int,
    K: int,
    P: int,
    phi0: int,
    alpha_max: float = 2.0,
    eps: float = 1e-3,
    agg="max",
):
    """Turn a normalized clip batch into an ISI-Phase spike train.

    The normalization is undone, the clip is collapsed over its time axis
    (``agg`` = "max" or "mean") and the resulting image is handed to
    ``isi_phase_fixedK_strict_timegrid``. Returns spikes [T_steps,B,3,H,W].

    Raises:
        ValueError: if ``agg`` is neither "max" nor "mean".
    """
    target_device = x_video_norm.device
    channel_mean = IM_MEAN.to(target_device)
    channel_std = IM_STD.to(target_device)

    # Back to raw intensities in [0,1]; shape stays [B,T,3,H,W].
    x_raw = (x_video_norm * channel_std + channel_mean).clamp(0.0, 1.0)

    if agg not in ("max", "mean"):
        raise ValueError("agg must be 'max' or 'mean'")

    # Collapse the frame axis -> [B,3,H,W]
    x_img = x_raw.max(dim=1).values if agg == "max" else x_raw.mean(dim=1)

    spikes = isi_phase_fixedK_strict_timegrid(
        x_img, T=T_steps, K=K, P=P, phi0=phi0, alpha_max=alpha_max, eps=eps
    )
    return spikes  # [T,B,3,H,W]


# =========================================================
# Conv3D SNN (analog logits head)
# =========================================================
class Conv3DSNN_ISIPhase(nn.Module):
    """Three Conv3d+BatchNorm3d+``snn.Leaky`` stages followed by a spiking MLP head.

    Input is a spike train [T,B,3,H,W]; output is one logit vector per time
    step, [T,B,num_classes].
    """

    def __init__(self, beta=0.95, num_classes=10):
        super().__init__()
        # Stage 1
        self.conv1 = nn.Conv3d(3, 32, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(32)
        self.lif1 = snn.Leaky(beta=beta)

        # Stage 2
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(64)
        self.lif2 = snn.Leaky(beta=beta)

        # Stage 3
        self.conv3 = nn.Conv3d(64, 128, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm3d(128)
        self.lif3 = snn.Leaky(beta=beta)

        # Spatial-only pooling; the time axis is never reduced.
        self.pool = nn.MaxPool3d((1, 2, 2))
        self.gap = nn.AdaptiveAvgPool3d((None, 1, 1))

        # Head
        self.fc1 = nn.Linear(128, 256)
        self.lif4 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(256, num_classes)

    def _lif_over_time_2d(self, x_bcthw, lif, mem):
        """Run ``lif`` step by step along dim 2 of a 5-D tensor.

        Returns the per-step spike outputs stacked back on dim 2 and the
        final membrane state.
        """
        _, _, n_steps, _, _ = x_bcthw.shape
        spikes_per_step = []
        for step in range(n_steps):
            spk_step, mem = lif(x_bcthw.select(2, step), mem)
            spikes_per_step.append(spk_step)
        return torch.stack(spikes_per_step, dim=2), mem

    def forward(self, spk_in):
        """Map spikes [T,B,3,H,W] to per-step logits [T,B,num_classes]."""
        x = spk_in.permute(1, 2, 0, 3, 4).contiguous()  # [B,3,T,H,W]

        # Fresh membrane state for every LIF layer on each forward pass.
        membranes = [
            lif.init_leaky() for lif in (self.lif1, self.lif2, self.lif3, self.lif4)
        ]

        stages = (
            (self.conv1, self.bn1, self.lif1),
            (self.conv2, self.bn2, self.lif2),
            (self.conv3, self.bn3, self.lif3),
        )
        for (conv, bn, lif), mem in zip(stages, membranes):
            x, _ = self._lif_over_time_2d(bn(conv(x)), lif, mem)
            x = self.pool(x)

        x = self.gap(x).squeeze(-1).squeeze(-1)  # [B,128,T]

        head_mem = membranes[3]
        logits_per_step = []
        for step in range(x.shape[2]):
            spk4, head_mem = self.lif4(self.fc1(x[:, :, step]), head_mem)
            logits_per_step.append(self.fc2(spk4))

        return torch.stack(logits_per_step, dim=0)  # [T,B,num_classes]


# =========================================================
# Build loaders from subset_paths (must exist)
# =========================================================
# The train split defines the label mapping; val/test reuse it so class ids agree.
_clip_kwargs = dict(n_frames=T_steps, frame_step=frame_step)
train_ds = UCFSubsetDataset(subset_paths["train"], training=True, **_clip_kwargs)
class_to_idx = train_ds.class_to_idx
val_ds, test_ds = (
    UCFSubsetDataset(subset_paths[split], training=False, class_to_idx=class_to_idx, **_clip_kwargs)
    for split in ("val", "test")
)

num_classes = len(class_to_idx)

# Options shared by all three loaders; only shuffling / drop_last differ.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True,
                      collate_fn=collate_btchw)
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader, test_loader = (
    DataLoader(eval_ds, shuffle=False, drop_last=False, **_loader_kwargs)
    for eval_ds in (val_ds, test_ds)
)

print(f"num_classes = {num_classes} | T_steps = {T_steps} | train batches = {len(train_loader)} | "
      f"ISI-Phase: K={K_spikes}, P={P}, phi0={phi0}, M={M}, agg={agg_mode}")


# =========================================================
# Train / Eval + BEST checkpoint
# =========================================================
model = Conv3DSNN_ISIPhase(beta=beta, num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Defaults for a fresh run; overwritten below when a LAST checkpoint exists.
start_epoch, best_val = 1, -1.0

# optional resume from LAST
if os.path.exists(last_ckpt_path):
    ckpt = torch.load(last_ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    scheduler.load_state_dict(ckpt["scheduler_state"])
    start_epoch = ckpt["epoch"] + 1
    best_val = ckpt.get("best_val", -1.0)
    print("Resumed from epoch", start_epoch, "best_val", best_val)

@torch.no_grad()
def evaluate(loader):
    """Evaluate the global ``model`` on ``loader``.

    Each batch is ISI-Phase encoded with the notebook-level settings, the
    per-step logits are averaged over time and scored with cross-entropy.
    Returns ``(mean_loss, accuracy)``; both are 0 for an empty loader.
    """
    model.eval()
    n_hit = 0
    n_seen = 0
    loss_total = 0.0

    for clips, targets in loader:
        clips = clips.to(device, non_blocking=True)      # [B,T,3,H,W]
        targets = targets.to(device, non_blocking=True)

        spikes = isiphase_encode_ucf_video(
            clips, T_steps=T_steps, K=K_spikes, P=P, phi0=phi0,
            alpha_max=alpha_max, eps=eps, agg=agg_mode,
        )  # [T,B,3,H,W]

        # Average the per-step logits [T,B,C] into one prediction per sample.
        logits = model(spikes).mean(dim=0)               # [B,C]
        batch_loss = F.cross_entropy(logits, targets)

        n_hit += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.numel()
        # Weight by batch size so the final mean is per-sample.
        loss_total += batch_loss.item() * targets.size(0)

    denom = max(n_seen, 1)
    return loss_total / denom, n_hit / denom

# Training loop: one pass over train_loader per epoch, then validation,
# then LAST (always) and BEST (on improvement) checkpoints.
for epoch in range(start_epoch, num_epochs + 1):
    model.train()
    running_loss, running_correct, running_total = 0.0, 0, 0

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        spk_in = isiphase_encode_ucf_video(
            x, T_steps=T_steps, K=K_spikes, P=P, phi0=phi0, alpha_max=alpha_max, eps=eps, agg=agg_mode
        )

        logits_rec = model(spk_in)
        logits = logits_rec.mean(dim=0)   # average per-step logits over time
        loss = F.cross_entropy(logits, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item() * y.size(0)
        running_correct += (logits.argmax(dim=1) == y).sum().item()
        running_total += y.size(0)

    scheduler.step()
    val_loss, val_acc = evaluate(val_loader)

    # Guard against an empty train_loader (drop_last=True can drop every batch).
    seen = max(running_total, 1)
    print(f"Epoch {epoch:02d} | "
          f"train loss {running_loss/seen:.4f} | "
          f"train acc {running_correct/seen:.4f} | "
          f"val loss {val_loss:.4f} | "
          f"val acc {val_acc:.4f}")

    # Update best_val BEFORE writing LAST: otherwise LAST stores the previous
    # best, and a resumed run could overwrite the BEST file with a worse model.
    is_best = val_acc > best_val
    if is_best:
        best_val = val_acc

    ckpt_state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_val": best_val,
        "class_to_idx": class_to_idx,
        "T_steps": T_steps,
        "K_spikes": K_spikes,
        "P": P,
        "phi0": phi0,
        "alpha_max": alpha_max,
        "agg_mode": agg_mode,
    }

    # save LAST
    torch.save(ckpt_state, last_ckpt_path)

    # save BEST
    if is_best:
        torch.save(ckpt_state, best_ckpt_path)
        print(f"Saved BEST: epoch {epoch} | best_val={best_val:.4f} -> {best_ckpt_path}")

# Final report: reload the best-on-validation weights and score the test split.
print("\nLoading BEST checkpoint for test...")
best = torch.load(best_ckpt_path, map_location=device)
model.load_state_dict(best["model_state"])
test_loss, test_acc = evaluate(test_loader)
print(f"BEST epoch {best['epoch']} | best_val {best['best_val']:.4f} | Test loss {test_loss:.4f} | Test acc {test_acc:.4f}")


num_classes = 10 | T_steps = 16 | train batches = 37 | ISI-Phase: K=4, P=2, phi0=0, M=8, agg=max
Epoch 01 | train loss 2.2863 | train acc 0.0946 | val loss 2.2057 | val acc 0.2300
Saved BEST: epoch 1 | best_val=0.2300 -> ./ucf_subset_isiphase_best.pth
Epoch 02 | train loss 2.1671 | train acc 0.1757 | val loss 2.0058 | val acc 0.2000
Epoch 03 | train loss 2.0578 | train acc 0.1959 | val loss 2.0957 | val acc 0.1700
Epoch 04 | train loss 2.0612 | train acc 0.2230 | val loss 1.9685 | val acc 0.2800
Saved BEST: epoch 4 | best_val=0.2800 -> ./ucf_subset_isiphase_best.pth
Epoch 05 | train loss 2.0633 | train acc 0.1993 | val loss 1.9312 | val acc 0.2400
Epoch 06 | train loss 1.9822 | train acc 0.2432 | val loss 2.2642 | val acc 0.1900
Epoch 07 | train loss 1.9482 | train acc 0.2432 | val loss 2.2099 | val acc 0.1600
Epoch 08 | train loss 1.9311 | train acc 0.2703 | val loss 1.8345 | val acc 0.3400
Saved BEST: epoch 8 | best_val=0.3400 -> ./ucf_subset_isiphase_best.pth
Epoch 09 | train loss 1

6. DNN

In [None]:
import os
import cv2
import numpy as np
import random
import pathlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# =========================================================
# Config
# =========================================================
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reproducibility: seed every RNG this cell relies on. Python's `random` picks
# the clip start in frames_from_video; torch drives weight init and the
# shuffling DataLoader. cudnn.benchmark (enabled below) may still introduce
# small run-to-run differences on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 8
num_epochs = 80
lr = 2e-3
weight_decay = 1e-4

# video sampling
T_steps    = 16     # frames_per_clip
frame_step = 2
H, W       = 112, 112

# paths (your subset structure must already exist)
# subset_paths = {"train": "...", "val": "...", "test": "..."}  # should exist from your downloader
# Example:
# subset_paths = {"train":"./UCF101_subset/train", "val":"./UCF101_subset/val", "test":"./UCF101_subset/test"}

last_ckpt_path = "./ucf_subset_3dcnn_last.pth"
best_ckpt_path = "./ucf_subset_3dcnn_best.pth"

torch.backends.cudnn.benchmark = True


# =========================================================
# Video utils
# =========================================================
def resize_with_pad(frame_bgr, out_hw=(112,112)):
    """Resize a frame to fit inside ``out_hw`` keeping aspect ratio, then zero-pad.

    Args:
        frame_bgr: image array whose first two dimensions are (height, width).
        out_hw: target ``(height, width)`` of the returned canvas.

    Returns:
        ``uint8`` array of shape ``(out_h, out_w, 3)`` with the resized frame
        centred on a black canvas. A frame with zero height or width yields an
        all-zero canvas.
    """
    out_h, out_w = out_hw
    h, w = frame_bgr.shape[:2]
    if h == 0 or w == 0:
        return np.zeros((out_h, out_w, 3), dtype=np.uint8)

    scale = min(out_w / w, out_h / h)
    # int() truncates, so a very thin/wide frame could produce a 0-pixel side,
    # which cv2.resize rejects; keep at least one pixel per side.
    nw = max(1, int(w * scale))
    nh = max(1, int(h * scale))
    resized = cv2.resize(frame_bgr, (nw, nh), interpolation=cv2.INTER_AREA)

    canvas = np.zeros((out_h, out_w, 3), dtype=np.uint8)
    # Centre the resized frame; the remaining border stays black.
    top = (out_h - nh) // 2
    left = (out_w - nw) // 2
    canvas[top:top+nh, left:left+nw] = resized
    return canvas

def frames_from_video(video_path, n_frames, out_hw=(112,112), frame_step=2, training=True):
    """Decode a fixed-length clip from a video file.

    Args:
        video_path: path of the video (converted with ``str`` for OpenCV).
        n_frames: number of frames to return.
        out_hw: ``(height, width)`` every frame is letterboxed to.
        frame_step: number of decoder reads between two kept frames.
        training: if True and the reported frame count is long enough, the
            clip starts at a random frame; otherwise it starts at frame 0.

    Returns:
        ``uint8`` array ``[n_frames, out_h, out_w, 3]`` in RGB order. If the
        file cannot be opened or the first read fails the clip is all zeros;
        frames that cannot be read later in the clip are replaced by black
        frames.
    """
    blank_clip_shape = (n_frames, out_hw[0], out_hw[1], 3)

    cap = cv2.VideoCapture(str(video_path))
    # try/finally guarantees the capture handle is released even if decoding
    # or resizing raises part-way through.
    try:
        if not cap.isOpened():
            return np.zeros(blank_clip_shape, dtype=np.uint8)

        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Span (in source frames) covered by n_frames samples taken every frame_step.
        need_len = 1 + (n_frames - 1) * frame_step
        if length <= 0 or need_len > length:
            start = 0
        else:
            max_start = length - need_len
            start = random.randint(0, max_start) if training else 0

        cap.set(cv2.CAP_PROP_POS_FRAMES, start)

        ret, frame = cap.read()
        if not ret:
            return np.zeros(blank_clip_shape, dtype=np.uint8)

        frames = [resize_with_pad(frame, out_hw)]

        for _ in range(n_frames - 1):
            # Advance frame_step reads; only the last one is kept.
            for _ in range(frame_step):
                ret, frame = cap.read()
            if ret:
                frames.append(resize_with_pad(frame, out_hw))
            else:
                # Past the end of the stream: pad with a black frame.
                frames.append(np.zeros_like(frames[0]))
    finally:
        cap.release()

    # BGR -> RGB, contiguous (avoid negative strides)
    frames = np.stack(frames, axis=0)[..., ::-1].copy()
    return frames  # [T,H,W,3] uint8 RGB


# =========================================================
# Dataset
# =========================================================
# Per-channel mean/std applied to clips in UCFSubsetDataset.__getitem__.
# Shaped (1, 3, 1, 1) so they broadcast against [T,3,H,W] clips.
IM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IM_STD  = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

class UCFSubsetDataset(Dataset):
    """Clips read from a ``<split_dir>/<class_name>/*.avi`` directory tree.

    Each item is ``(clip, label)`` where ``clip`` is a normalized float tensor
    ``[T,3,H,W]`` and ``label`` is a scalar ``long`` tensor.
    """

    def __init__(self, split_dir, n_frames=16, frame_step=2, training=False, class_to_idx=None):
        """Index the videos under ``split_dir``.

        Args:
            split_dir: root folder of one split.
            n_frames: frames per clip.
            frame_step: decoder reads between kept frames.
            training: forwarded to ``frames_from_video`` (random clip start).
            class_to_idx: optional class-name -> id mapping to reuse; when
                omitted it is built from the sorted class folder names.
        """
        self.split_dir = pathlib.Path(split_dir)
        self.n_frames = n_frames
        self.frame_step = frame_step
        self.training = training

        # Sorted so item order and default label ids are stable between runs.
        self.video_paths = sorted(self.split_dir.glob("*/*.avi"))
        self.class_names = sorted(set(path.parent.name for path in self.video_paths))

        self.class_to_idx = (
            class_to_idx
            if class_to_idx is not None
            else dict(zip(self.class_names, range(len(self.class_names))))
        )

        # The parent folder name is the class of each video.
        self.labels = [self.class_to_idx[path.parent.name] for path in self.video_paths]

    def __len__(self):
        """Number of indexed videos."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Decode, scale to [0,1] and normalize the clip at position ``idx``."""
        video_file = self.video_paths[idx]
        label_id = self.labels[idx]

        clip_uint8 = frames_from_video(
            video_file, self.n_frames, (H, W), self.frame_step, self.training
        )
        # [T,H,W,3] uint8 -> [T,3,H,W] float in [0,1]
        clip = torch.from_numpy(clip_uint8).permute(0, 3, 1, 2).float() / 255.0
        clip = (clip - IM_MEAN) / IM_STD
        return clip, torch.tensor(label_id, dtype=torch.long)

def collate_btchw(batch):
    """Stack ``(clip, label)`` samples into batch tensors.

    Returns ``(clips, labels)`` where ``clips`` has a new leading batch
    dimension ([B,T,3,H,W] for [T,3,H,W] samples) and ``labels`` is ``[B]``.
    """
    clips, targets = map(list, zip(*batch))
    stacked_clips = torch.stack(clips, dim=0)      # [B,T,3,H,W]
    stacked_targets = torch.stack(targets, dim=0)
    return stacked_clips, stacked_targets


# =========================================================
# Plain 3D CNN (no encoding, no SNN)
# =========================================================
class Plain3DCNN(nn.Module):
    """Four Conv3d-BatchNorm3d-ReLU blocks, global average pooling and a linear classifier.

    Takes clips [B,T,3,H,W] and returns logits [B,num_classes].
    """

    @staticmethod
    def _conv_block(c_in, c_out, pool=None):
        """Conv3d(3x3x3, pad 1, no bias) -> BatchNorm3d -> ReLU [-> MaxPool3d(pool)]."""
        layers = [
            nn.Conv3d(c_in, c_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(c_out),
            nn.ReLU(inplace=True),
        ]
        if pool is not None:
            layers.append(nn.MaxPool3d(kernel_size=pool))
        return nn.Sequential(*layers)

    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = self._conv_block(3, 32, pool=(1, 2, 2))       # keep time
        self.block1 = self._conv_block(32, 64, pool=(2, 2, 2))    # downsample time+spatial
        self.block2 = self._conv_block(64, 128, pool=(2, 2, 2))
        self.block3 = self._conv_block(128, 256)                  # no pooling

        self.gap = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x_btchw):
        """Classify a batch of clips laid out as [B,T,3,H,W]."""
        # Conv3d expects channels before time: [B,T,3,H,W] -> [B,3,T,H,W]
        feats = x_btchw.permute(0, 2, 1, 3, 4).contiguous()

        for stage in (self.stem, self.block1, self.block2, self.block3):
            feats = stage(feats)

        pooled = self.gap(feats).flatten(1)   # [B,256]
        return self.fc(pooled)                # [B,num_classes]


# =========================================================
# Build loaders
# =========================================================
# ---- IMPORTANT: subset_paths must exist in your runtime ----
# Example if you need:
# subset_paths = {"train":"./UCF101_subset/train", "val":"./UCF101_subset/val", "test":"./UCF101_subset/test"}

# The train split defines the label mapping; val/test reuse it so class ids agree.
_clip_kwargs = dict(n_frames=T_steps, frame_step=frame_step)
train_ds = UCFSubsetDataset(subset_paths["train"], training=True, **_clip_kwargs)
class_to_idx = train_ds.class_to_idx
val_ds, test_ds = (
    UCFSubsetDataset(subset_paths[split], training=False, class_to_idx=class_to_idx, **_clip_kwargs)
    for split in ("val", "test")
)

num_classes = len(class_to_idx)

# Options shared by all three loaders; only shuffling / drop_last differ.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True,
                      collate_fn=collate_btchw)
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader, test_loader = (
    DataLoader(eval_ds, shuffle=False, drop_last=False, **_loader_kwargs)
    for eval_ds in (val_ds, test_ds)
)

print(f"num_classes = {num_classes} | T_steps = {T_steps} | train batches = {len(train_loader)}")


# =========================================================
# Train / Eval + BEST checkpoint
# =========================================================
model = Plain3DCNN(num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# Defaults for a fresh run; overwritten below when a LAST checkpoint exists.
start_epoch, best_val = 1, -1.0

# optional resume from LAST
if os.path.exists(last_ckpt_path):
    ckpt = torch.load(last_ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    scheduler.load_state_dict(ckpt["scheduler_state"])
    start_epoch = ckpt["epoch"] + 1
    best_val = ckpt.get("best_val", -1.0)
    print("Resumed from epoch", start_epoch, "best_val", best_val)

@torch.no_grad()
def evaluate(loader):
    """Evaluate the global ``model`` on ``loader``.

    Returns ``(mean_loss, accuracy)`` computed per sample with cross-entropy;
    both are 0 for an empty loader.
    """
    model.eval()
    n_hit = 0
    n_seen = 0
    loss_total = 0.0

    for clips, targets in loader:
        clips = clips.to(device, non_blocking=True)      # [B,T,3,H,W]
        targets = targets.to(device, non_blocking=True)

        logits = model(clips)                            # [B,C]
        batch_loss = F.cross_entropy(logits, targets)

        n_hit += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.numel()
        # Weight by batch size so the final mean is per-sample.
        loss_total += batch_loss.item() * targets.size(0)

    denom = max(n_seen, 1)
    return loss_total / denom, n_hit / denom

# Training loop: one pass over train_loader per epoch, then validation,
# then LAST (always) and BEST (on improvement) checkpoints.
for epoch in range(start_epoch, num_epochs + 1):
    model.train()
    running_loss, running_correct, running_total = 0.0, 0, 0

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        logits = model(x)
        loss = F.cross_entropy(logits, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item() * y.size(0)
        running_correct += (logits.argmax(dim=1) == y).sum().item()
        running_total += y.size(0)

    scheduler.step()
    val_loss, val_acc = evaluate(val_loader)

    # Guard against an empty train_loader (drop_last=True can drop every batch).
    seen = max(running_total, 1)
    print(f"Epoch {epoch:02d} | "
          f"train loss {running_loss/seen:.4f} | "
          f"train acc {running_correct/seen:.4f} | "
          f"val loss {val_loss:.4f} | "
          f"val acc {val_acc:.4f}")

    # Update best_val BEFORE writing LAST: otherwise LAST stores the previous
    # best, and a resumed run could overwrite the BEST file with a worse model.
    is_best = val_acc > best_val
    if is_best:
        best_val = val_acc

    ckpt_state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_val": best_val,
        "class_to_idx": class_to_idx,
        "T_steps": T_steps,
        "frame_step": frame_step,
        "H": H, "W": W,
    }

    # -------- save LAST --------
    torch.save(ckpt_state, last_ckpt_path)

    # -------- save BEST --------
    if is_best:
        torch.save(ckpt_state, best_ckpt_path)
        print(f"Saved BEST: epoch {epoch} | best_val={best_val:.4f} -> {best_ckpt_path}")


# =========================================================
# Test using BEST checkpoint
# =========================================================
print("\nLoading BEST checkpoint for test...")
best = torch.load(best_ckpt_path, map_location=device)
model.load_state_dict(best["model_state"])

test_loss, test_acc = evaluate(test_loader)
print(f"BEST epoch {best['epoch']} | best_val {best['best_val']:.4f} | Test loss {test_loss:.4f} | Test acc {test_acc:.4f}")


num_classes = 10 | T_steps = 16 | train batches = 37
Epoch 01 | train loss 2.0690 | train acc 0.2872 | val loss 2.0641 | val acc 0.2500
Saved BEST: epoch 1 | best_val=0.2500 -> ./ucf_subset_3dcnn_best.pth
Epoch 02 | train loss 1.7658 | train acc 0.3682 | val loss 1.8308 | val acc 0.3300
Saved BEST: epoch 2 | best_val=0.3300 -> ./ucf_subset_3dcnn_best.pth
Epoch 03 | train loss 1.5896 | train acc 0.4392 | val loss 2.2119 | val acc 0.2200
Epoch 04 | train loss 1.3971 | train acc 0.5068 | val loss 3.2156 | val acc 0.2200
Epoch 05 | train loss 1.3402 | train acc 0.5236 | val loss 1.8561 | val acc 0.3800
Saved BEST: epoch 5 | best_val=0.3800 -> ./ucf_subset_3dcnn_best.pth
Epoch 06 | train loss 1.2737 | train acc 0.5101 | val loss 1.6452 | val acc 0.3400
Epoch 07 | train loss 1.1396 | train acc 0.6047 | val loss 1.9412 | val acc 0.3100
Epoch 08 | train loss 1.0957 | train acc 0.5912 | val loss 1.5799 | val acc 0.4200
Saved BEST: epoch 8 | best_val=0.4200 -> ./ucf_subset_3dcnn_best.pth
Epoch 0