In [1]:
'''
    resnet 3D CNN + kinetics-400 pretrained + hmdb51 training
    Reference paper: Can Spatiotemporal 3D CNNs Retrace the History of 2D CNNs and ImageNet?
'''

import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchsummaryX import summary
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

import pretrained_model as model
import os
import time
import copy
import warnings
import utils

In [2]:
warnings.filterwarnings("ignore")

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
model_ft = model.get_model().to(device)
criterion =  nn.CrossEntropyLoss()
optimizer = optim.Adam(model_ft.parameters())
log_name = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(time.time()))
writer = SummaryWriter(os.path.join('log/', log_name))

print(model_ft)
summary(model_ft, torch.zeros(1, 3, 8, 128, 128))

VideoResNet(
  (stem): BasicStem(
    (0): Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1):

RuntimeError: Could not run 'aten::slow_conv3d_forward' with arguments from the 'CUDATensorId' backend. 'aten::slow_conv3d_forward' is only available for these backends: [CPUTensorId, VariableTensorId].

In [3]:
best_val_acc = 0.0
best_val_model = None
EPOCH = 10
BSZ = 2

transform = transforms.Compose([utils.RandomResizedCropVideo(128), utils.ToTensorVideo()])

data_root = '../../Data/HMDB'

train_loader = DataLoader(datasets.HMDB51(root = os.path.join(data_root, 'hmdb51'), 
                                          annotation_path = os.path.join(data_root,'splits'),
                                          frames_per_clip = 8, fold = 1, train = True, 
                                          transform = transform), 
                          batch_size = BSZ, shuffle = True, num_workers = 4)

val_loader = DataLoader(datasets.HMDB51(root = os.path.join(data_root, 'hmdb51'), 
                                        annotation_path = os.path.join(data_root, 'splits'),
                                        frames_per_clip = 8, fold = 2, train = False, 
                                        transform = transform), 
                        batch_size = BSZ, shuffle = True, num_workers = 4)

test_loader = DataLoader(datasets.HMDB51(root = os.path.join(data_root, 'hmdb51'), 
                                         annotation_path = os.path.join(data_root, 'splits'),
                                         frames_per_clip = 8, fold = 3, train = False, 
                                         transform = transform), 
                         batch_size = BSZ, shuffle = True, num_workers = 4)



sample_size = next(iter(train_loader))[0].size()
assert sample_size == torch.Size([BSZ, 3, 8, 128, 128]), 'sample_size is {}'.format(sample_size)

  "follow-up version. Please use pts_unit 'sec'.")
100.0%
  "follow-up version. Please use pts_unit 'sec'.")
100.0%
  "follow-up version. Please use pts_unit 'sec'.")
100.0%
  "follow-up version. Please use pts_unit 'sec'.")
  "follow-up version. Please use pts_unit 'sec'.")
  "follow-up version. Please use pts_unit 'sec'.")
  "follow-up version. Please use pts_unit 'sec'.")
  "See the documentation of nn.Upsample for details.".format(mode))
  "See the documentation of nn.Upsample for details.".format(mode))
  "See the documentation of nn.Upsample for details.".format(mode))
  "See the documentation of nn.Upsample for details.".format(mode))


In [None]:
for epoch in range(EPOCH):
    
    start_time = time.strftime("%H:%M:%S", time.localtime(time.time()))
    print(f'epoch {epoch} | start time {start_time}')

    train_loss, train_acc = model.train(model_ft, train_loader, criterion, optimizer, epoch, writer, device)
    val_loss, val_acc = model.evaluate(model_ft, val_loader, criterion, epoch, writer, device)

    print(f'train loss {train_loss:03f} | train accuracy {train_acc:03f}')
    print(f'val loss {val_loss:03f} | val accuracy {val_acc:03f}\n')

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_val_model = copy.deepcopy(model_ft.state_dict())

model_ft.load_state_dict(best_val_model)
test_loss, test_acc = model.evaluate(model_ft, test_loader, criterion, EPOCH+1, writer, device)
print(f'test_loss {test_loss:03f} | test acc {test_acc:03f}')



epoch 0 | start time 11:42:19


  "follow-up version. Please use pts_unit 'sec'.")
  "follow-up version. Please use pts_unit 'sec'.")
  "follow-up version. Please use pts_unit 'sec'.")
  "follow-up version. Please use pts_unit 'sec'.")
  "See the documentation of nn.Upsample for details.".format(mode))
  "See the documentation of nn.Upsample for details.".format(mode))
  "See the documentation of nn.Upsample for details.".format(mode))
  "See the documentation of nn.Upsample for details.".format(mode))
