### Training with augmentation Herve 1- RUNNING
Code taken from: https://github.com/piergiaj/pytorch-i3d/blob/master/train_i3d.py 

Note: This code was written for PyTorch 0.3. Version 0.4 and newer may cause issues.

# TO CHANGE BEFORE RUNNING

Set `augment = True` below for data to be augmented, and `false` otherwise.

In [1]:
is_augment = False

Set `dropout = True` below for dropout to be included, and `false` otherwise.

In [2]:
is_dropout = False

Dropout details:

In [3]:
dropout_details = "layer1_p0.5"

Learning rate

In [4]:
learning_rate = 0.1

Set `l2 = True` below for L2 Regularization, and `false` for L1 Regularization.

In [5]:
l2 = False

Set weight decay value, `wd`, for L2 Regularization

In [6]:
wd = None

Set `lambda` for L1 Regularization 

In [7]:
lambda1 = 1e-2

Set the number of epochs in training:

In [8]:
num_epochs = 30

**ALL FILES INCLUDING LOSSES AND THE MODEL WILL BE SAVED WITH THIS NAME:**

In [9]:
# "30epochs_wd_1e-07_dropout__augmented" means the there are 30 training epochs, weight decay is 1e-07, and that there is dropout and augmentation
save_name = f"{num_epochs}epochs"
if (not l2): save_name = save_name + "_l1_lr_" + str(learning_rate) + "_ld_" + str(lambda1) # l1 regularization
if l2: save_name = save_name + "_l2_lr_" + str(learning_rate) + "_wd_"+ str(wd) # l2 regularization
if is_dropout: save_name = save_name + "_dropout_"+dropout_details
if is_augment: save_name = save_name + "_augment"

In [10]:
# check save_name
save_name

'30epochs_l1_lr_0.1_ld_0.01'

# CODE

Import packages

In [11]:
import os
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]='2'
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:<1024>"
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
import numpy as np
from pytorch_i3d import InceptionI3d
import numpy as np
import glob
import random
from tensorboardX import SummaryWriter
from preprocess import run_preprocessing, get_action, holdout_set
import time
import matplotlib.pyplot as plt
import pandas as pd
from collections import Counter
from PIL import Image, ImageSequence

# video augmentation scripts (c) 2018 okankop
from vidaug import *

Construct a dataset class for training the model:

In [12]:
class dataset(torch.utils.data.Dataset):
    
    def __init__(self, paths, v_names, v_labels, num_samples=16, transforms=None): # num_samples cannot be lower than 16
        self.num_samples = num_samples
        self.frames = dict()
        for p in paths:
            self.frames[p] = sorted(glob.glob(p+"/*.jpg"))
        self.data = paths
        self.video_names = v_names
        self.video_labels = v_labels
        self.transforms = transforms
    
    def __getitem__(self, idx):
        # get original video
        p = self.data[idx]
        
        # sample frames uniformly and create newly sampled video 
        num_frames = len(self.frames[p])-1
        sampled_idx = np.linspace(0, num_frames, self.num_samples) #get num_samples frames from the video
        images = []
        index = np.where(self.video_names == p.split('/')[-1]) #index of p's video name in video_names
        label_video = self.video_labels[index] # the labels for the video
        for i in sampled_idx:
            image = torchvision.io.read_image(self.frames[p][int(i)])
            small_dim = min(image.shape[-2:])
            image = torchvision.transforms.functional.center_crop(image, (small_dim, small_dim))
            image = torchvision.transforms.functional.resize(image, (224, 224), antialias=True)
            images.append(image)
        images = torch.stack(images, axis=1)
        
        # data augmentation 
        if (self.transforms is not None):
            images = np.array(self.transforms(images.numpy()))
            # normalize
            images = (images/255)*2 - 1 # values are between -1 and 1
            return torch.from_numpy(images).type(torch.FloatTensor), label_video 
        
        else: 
            images = (images/255)*2 - 1 #values are between -1 and 1
            return images, label_video 

    def __len__(self):
        return len(self.data)

Build transformations for data augmentation

In [13]:
if is_augment:
    sometimes = lambda aug: Sometimes(0.4, aug) # Used to apply augmentor with 40% probability
    rand_aug = SomeOf([ # randomly chooses two of the following augmentation methods 
        RandomRotate(degrees=10), # randomly rotates the video with a degree randomly choosen from [-10, 10] 
        RandomTranslate(x=40,y=20), # randomly shifting video in [-x, +x] and [-y, +y] coordinate
        RandomShear(x=0.2,y=0.1), # randomly shearing video in [-x, +x] and [-y, +y] directions.
        sometimes(HorizontalFlip()), # horizontally flip the video with 50% probability
        sometimes(GaussianBlur(sigma=random.uniform(0.5,4))), # blur images using gaussian kernels with std. dev. = sigma
        sometimes(ElasticTransformation(alpha=random.uniform(0,5), cval=int(random.uniform(0,255)), mode="nearest")), # moving pixels locally around using displacement fields
        sometimes(PiecewiseAffineTransform(displacement=15, displacement_kernel=1, displacement_magnification=1)), # places a regular grid of points on an image and randomly moves the neighbourhood of these point around via affine transformations
        sometimes(Add(value=int(random.uniform(-100,100)))), # add a value to all pixel intesities in an video
        sometimes(Multiply(value=2)), # multiply all pixel intensities with given value
        sometimes(Multiply(value=0.5)), # multiply all pixel intensities with given value
        sometimes(Pepper(ratio=25)), # sets a certain fraction of pixel intensities to 0
        sometimes(Salt(ratio=25)), # sets a certain fraction of pixel intensities to 255
    ], 2) # only select two of the above augmenters each time

Extract data and labels

In [14]:
video_train, video_val, label_train, label_val, unique_labels = holdout_set(0.25) #valid names and videos
batch_size = 10 # batch size in training
num_videos_train = len(video_train)
num_videos_val = len(video_val)
num_classes = len(set(label_train)) #count unique in labels

video_frames_path = "/scratch/network/hishimwe/image" 
# only extract the videos with v_names and v_labels from preprocess.ipynb 
paths = glob.glob(video_frames_path+"/*")
random.seed(0)
random.shuffle(paths)

good_paths_train = list(filter(lambda c: c.split('/')[-1] in video_train, paths)) #should only get path where good video name; not sure if this filtering will work 
good_paths_val = list(filter(lambda c: c.split('/')[-1] in video_val, paths)) # validation video paths 

if is_augment: d_train = dataset(paths=good_paths_train, v_names=video_train, v_labels= label_train, transforms=rand_aug)
else: d_train = dataset(paths=good_paths_train, v_names=video_train, v_labels= label_train)
d_val = dataset(paths=good_paths_val, v_names=video_val, v_labels= label_val)

loader_train = torch.utils.data.DataLoader(d_train, shuffle=True, batch_size=batch_size, drop_last=False, num_workers=4)
loader_val = torch.utils.data.DataLoader(d_val, shuffle=True, batch_size=batch_size, drop_last=False, num_workers=4)

Construct the model:

In [15]:
start_time = time.time() 
i3d = InceptionI3d(400, in_channels=3) # first input is num_classes in kinetics, this is replaced with replace_logits

if is_dropout: i3d.load_state_dict(torch.load('rgb_imagenet.pt'), strict=False) #added strict = false; theoretically this lets us add layers
else: i3d.load_state_dict(torch.load('rgb_imagenet.pt')) 

i3d.replace_logits(num_classes)
i3d.cuda()

print(f"time taken: {time.time()-start_time} seconds")

time taken: 1.1249988079071045 seconds


Function to evaluate model performance:

In [16]:
#returns accuracy, f1 score, average f1, and confusion matrix for the data
def eval_metrics(ground_truth, predictions, num_classes):

    #dictionary containing the accuracy, precision, f1, avg f1, and confusion matrix for the data
    f1 = f1_score(y_true=ground_truth, y_pred=predictions, labels=np.arange(num_classes), average=None)
    metrics = {
        "accuracy": accuracy_score(y_true=ground_truth, y_pred=predictions),
        "f1": f1,
        "average f1": np.mean(f1),
        "confusion matrix": confusion_matrix(y_true=ground_truth, y_pred=predictions, labels=np.arange(num_classes)),
        "precision": precision_score(y_true=ground_truth, y_pred=predictions, labels=np.arange(num_classes), average=None)
        }
    
    return metrics

Function to train and validate:

In [17]:
def training(model, optimizer, loader, num_classes, reg_type, ld=None):
    losses = []
    ground_truth = []
    predictions = []
    for data, label in loader:
        data = data.cuda()
        label = label.squeeze().type(torch.LongTensor).cuda()
        num_frames = data.size(2)
        per_frame_logits = i3d(data).mean(2)
        preds = per_frame_logits.cpu().detach().numpy().argmax(axis=1) # convert logits into predictions for evaluating accuracy
        
        # calculate and save loss
        loss = F.cross_entropy(per_frame_logits, label)
        losses.append(loss.item()) # append to losses
        ground_truth.extend(list(label.cpu().detach().numpy()))
        predictions.extend(preds.tolist())
        
        if (not reg_type): # l1 regularization
            params = torch.cat([p.view(-1) for p in model.parameters()]) # weights
            norm = torch.norm(params, 1)
            loss = loss - (ld * norm) # updating loss
             
        # back propagation    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    metrics = eval_metrics(ground_truth, predictions, num_classes)   
    return np.mean(losses), metrics # one loss per epoch and the corresponding metrics        


In [18]:
def evaluate(model, loader, num_classes):
    losses = []
    ground_truth = []
    predictions = []
    for data, label in loader:
        data = data.cuda()
        label = label.squeeze().type(torch.LongTensor).cuda()
        num_frames = data.size(2)
        per_frame_logits = i3d(data).mean(2)
        preds = per_frame_logits.cpu().detach().numpy().argmax(axis=1) # convert logits into predictions for evaluating accuracy
        
        # calculate and save loss
        loss = F.cross_entropy(per_frame_logits, label)
        losses.append(loss.item()) # append to losses
        ground_truth.extend(list(label.cpu().detach().numpy()))
        predictions.extend(preds.tolist())
        
    metrics = eval_metrics(ground_truth, predictions, num_classes)
    return np.mean(losses), metrics # one loss per epoch and the corresponding metrics
    

Train

In [19]:
# set up gradient descent params

if (l2): # l2 regularization 
    optimizer = optim.SGD(i3d.parameters(), lr=learning_rate, momentum=0.9, weight_decay=wd) # weight_decay = l2 regularization
    lr_sched = optim.lr_scheduler.MultiStepLR(optimizer, [300, 1000])
else: # l1 regularization
    optimizer = optim.SGD(i3d.parameters(), lr=learning_rate, momentum=0.9) 
    lr_sched = optim.lr_scheduler.MultiStepLR(optimizer, [300, 1000])


# save performance
train_losses = []
train_accuracies = []
train_precisions = []
val_losses = []
val_accuracies = []
val_precisions = []

# train
for e in range(num_epochs):
    start_time = time.time()
    
    print("EPOCH", e)
    
    # training
    loss_train, metrics_train = training(model=i3d, optimizer=optimizer, loader=loader_train, num_classes=num_classes, reg_type=l2, ld=lambda1)
    train_losses.append(loss_train)
    train_accuracies.append(metrics_train["accuracy"])
    train_precisions.append(metrics_train["precision"])
    
    print("TRAINING")
    print("Loss", loss_train)
    print("Accuracy", metrics_train["accuracy"])
    print("Precision", metrics_train["precision"])
    
    # validation 
    loss_val, metrics_val = evaluate(model=i3d, loader=loader_val, num_classes=num_classes)
    val_losses.append(loss_val)
    val_accuracies.append(metrics_val["accuracy"])
    val_precisions.append(metrics_val["precision"])
    
    np.savetxt('/home/jt9744/COS429/429_Final/herve_losses/train/train_'+ save_name, np.array(train_losses), delimiter=",")
    np.savetxt('/home/jt9744/COS429/429_Final/herve_losses/val/val_' + save_name, np.array(val_losses), delimiter=",")

    np.savetxt('/home/jt9744/COS429/429_Final/herve_accuracies/train/train_'+save_name, np.array(train_accuracies), delimiter=",")
    np.savetxt('/home/jt9744/COS429/429_Final/herve_accuracies/val/val_'+save_name, np.array(val_accuracies), delimiter=",")

    np.savetxt('/home/jt9744/COS429/429_Final/herve_precisions/train/train_'+save_name, np.array(train_precisions), delimiter=",")
    np.savetxt('/home/jt9744/COS429/429_Final/herve_precisions/val/val_'+save_name, np.array(val_precisions), delimiter=",")
    
    print("VALIDATION")
    print("Loss", loss_val)
    print("Accuracy", metrics_val["accuracy"])
    print("Precision", metrics_val["precision"])
        
    print(f"Time taken for epoch {e}: {(time.time()-start_time)/60} mins")
    print("-----------------------------------------------------------------------")

EPOCH 0
TRAINING
Loss 3.1925109129533213
Accuracy 0.1643075215098529
Precision [0.05882353 0.         0.21448664 0.         0.16795866 0.15677966
 0.10377358 0.12780269 0.1        0.05263158 0.10769231]
VALIDATION
Loss 3.4596562060442837
Accuracy 0.17554076539101499
Precision [0.125      0.         0.25       0.         0.16666667 0.08843537
 0.05769231 0.15254237 0.1025641  0.11111111 0.03846154]
Time taken for epoch 4: 3.6616537531216937 mins
-----------------------------------------------------------------------
EPOCH 5
TRAINING
Loss 3.4739198437027654
Accuracy 0.15570358034970858
Precision [0.03658537 0.         0.19818182 0.         0.14419226 0.14220183
 0.08333333 0.14322251 0.07746479 0.02702703 0.07377049]
VALIDATION
Loss 3.610465953172731
Accuracy 0.13810316139767054
Precision [0.07692308 0.         0.18828452 0.         0.14197531 0.0625
 0.06896552 0.11538462 0.26666667 0.17647059 0.09302326]
Time taken for epoch 5: 3.6585108002026874 mins
----------------------------------

  _warn_prf(average, modifier, msg_start, len(result))


TRAINING
Loss 4.084185120471627
Accuracy 0.1823480432972523
Precision [0.         0.         0.21907601 0.         0.17326733 0.1010101
 0.08910891 0.12722646 0.05479452 0.0625     0.03921569]
VALIDATION
Loss 3.7071762291853094
Accuracy 0.194675540765391
Precision [0.25       0.         0.21428571 0.         0.22058824 0.07407407
 0.14285714 0.         0.0625     0.         0.        ]
Time taken for epoch 11: 3.6507438739140827 mins
-----------------------------------------------------------------------
EPOCH 12


  _warn_prf(average, modifier, msg_start, len(result))


TRAINING
Loss 4.15687751373756
Accuracy 0.1826255897862892
Precision [0.14705882 0.         0.20899719 0.08333333 0.17676143 0.10743802
 0.10843373 0.10631229 0.09259259 0.         0.1025641 ]
VALIDATION
Loss 3.832592469601592
Accuracy 0.19883527454242927
Precision [0.2        0.         0.21397849 0.         0.16083916 0.1875
 0.075      0.22580645 0.11111111 0.         0.        ]
Time taken for epoch 12: 3.6496246973673503 mins
-----------------------------------------------------------------------
EPOCH 13


  _warn_prf(average, modifier, msg_start, len(result))


TRAINING
Loss 4.168100578963261
Accuracy 0.18179295031917847
Precision [0.08571429 0.         0.20041929 0.         0.15857143 0.13483146
 0.10526316 0.16814159 0.05       0.         0.08108108]
VALIDATION
Loss 4.443883483074913
Accuracy 0.194675540765391
Precision [0.25       0.         0.21215352 0.         0.18604651 0.
 0.07142857 0.14814815 0.04       0.         0.        ]
Time taken for epoch 13: 3.687841256459554 mins
-----------------------------------------------------------------------
EPOCH 14
TRAINING
Loss 4.3738785866555085
Accuracy 0.18318068276436303
Precision [0.04545455 0.         0.20911063 0.         0.16047548 0.12871287
 0.08571429 0.12923077 0.07843137 0.         0.10526316]
VALIDATION
Loss 4.9685623754154555
Accuracy 0.1930116472545757
Precision [0.16666667 0.         0.20576132 0.33333333 0.15652174 0.0625
 0.1875     0.13333333 0.125      0.         0.        ]
Time taken for epoch 14: 3.691788983345032 mins
----------------------------------------------------

  _warn_prf(average, modifier, msg_start, len(result))


TRAINING
Loss 4.4307004768102125
Accuracy 0.17790729947266168
Precision [0.         0.         0.20091116 0.         0.15604396 0.1443299
 0.07936508 0.14225941 0.02439024 0.2        0.1       ]
VALIDATION
Loss 4.496937647338741
Accuracy 0.20632279534109818
Precision [0.22222222 0.         0.22133599 0.         0.18181818 0.
 0.15384615 0.         0.08       0.         0.        ]
Time taken for epoch 16: 3.696544154485067 mins
-----------------------------------------------------------------------
EPOCH 17
VALIDATION
Loss 4.365154358966292
Accuracy 0.21381031613976706
Precision [0.         0.         0.2230997  0.         0.16666667 0.07142857
 0.27272727 0.18965517 0.1        0.         0.33333333]
Time taken for epoch 17: 3.6829604268074037 mins
-----------------------------------------------------------------------
EPOCH 18
TRAINING
Loss 4.710883401767699
Accuracy 0.19122953094643352
Precision [0.11111111 0.         0.20793269 0.         0.17138599 0.15384615
 0.11111111 0.11282051

  _warn_prf(average, modifier, msg_start, len(result))


TRAINING
Loss 4.966830331202689
Accuracy 0.18762142658895364
Precision [0.         0.         0.20495406 0.         0.17546848 0.10526316
 0.08333333 0.12784091 0.03225806 0.11111111 0.22727273]


  _warn_prf(average, modifier, msg_start, len(result))


VALIDATION
Loss 5.9782613525705885
Accuracy 0.2129783693843594
Precision [0.         0.         0.22309198 0.         0.14893617 0.21428571
 0.         0.22222222 0.         0.         0.2       ]
Time taken for epoch 20: 3.6868186473846434 mins
-----------------------------------------------------------------------
EPOCH 21


  _warn_prf(average, modifier, msg_start, len(result))


TRAINING
Loss 5.314562183007639
Accuracy 0.18734388009991673
Precision [0.07692308 0.         0.20608108 0.         0.15937149 0.06153846
 0.16666667 0.14117647 0.11538462 0.         0.27272727]


  _warn_prf(average, modifier, msg_start, len(result))


VALIDATION
Loss 6.053430880396819
Accuracy 0.19217970049916805
Precision [0.03703704 0.         0.20591039 0.         0.14285714 0.
 0.09677419 0.15384615 0.         0.         0.        ]
Time taken for epoch 21: 3.6877763708432516 mins
-----------------------------------------------------------------------
EPOCH 22


  _warn_prf(average, modifier, msg_start, len(result))


TRAINING
Loss 5.5315894581934755
Accuracy 0.19539272828198723
Precision [0.16666667 0.         0.20881671 0.         0.17320704 0.10869565
 0.14285714 0.14018692 0.07142857 0.         0.19047619]
VALIDATION
Loss 5.432086864778818
Accuracy 0.2096505823627288
Precision [0.09090909 0.         0.21792453 0.         0.225      0.
 0.         0.16129032 0.         0.         0.5       ]
Time taken for epoch 22: 3.6809189279874164 mins
-----------------------------------------------------------------------
EPOCH 23


  _warn_prf(average, modifier, msg_start, len(result))


TRAINING
Loss 4.833652178666599
Accuracy 0.1945600888148765
Precision [0.07692308 0.         0.20757072 0.         0.17703349 0.15254237
 0.07692308 0.16289593 0.13043478 0.         0.27272727]
VALIDATION
Loss 6.477335900314583
Accuracy 0.2079866888519135
Precision [0.         0.         0.21515435 0.         0.18333333 0.14285714
 0.125      0.19354839 0.         0.         0.        ]
Time taken for epoch 23: 3.7085154175758364 mins
-----------------------------------------------------------------------
EPOCH 24
TRAINING
Loss 4.9689054482531345
Accuracy 0.19511518179295032
Precision [0.25       0.         0.21610845 0.14285714 0.15336658 0.0754717
 0.13953488 0.14414414 0.05882353 0.125      0.26086957]
VALIDATION
Loss 5.592227413634624
Accuracy 0.19550748752079866
Precision [0.         0.         0.20790216 0.         0.12345679 0.03225806
 0.33333333 0.25       0.         0.         0.        ]
Time taken for epoch 24: 3.71129424571991 mins
-----------------------------------------

  _warn_prf(average, modifier, msg_start, len(result))


TRAINING
Loss 5.672179143844879
Accuracy 0.18151540383014156
Precision [0.04761905 0.         0.19682294 0.         0.15059445 0.12244898
 0.11428571 0.13513514 0.05       0.         0.2173913 ]
VALIDATION
Loss 6.269763348516354
Accuracy 0.15806988352745424
Precision [0.         0.         0.21428571 0.         0.15869981 0.14285714
 0.125      0.04       0.         0.         0.33333333]
Time taken for epoch 25: 3.7150313019752503 mins
-----------------------------------------------------------------------
EPOCH 26
TRAINING
Loss 5.411199210423182
Accuracy 0.1806827643630308
Precision [0.18181818 0.         0.20345964 0.         0.15322581 0.09677419
 0.09677419 0.08823529 0.03703704 0.         0.4375    ]
VALIDATION
Loss 6.456326726054357
Accuracy 0.20133111480865223
Precision [0.         0.         0.21243043 0.         0.19047619 0.
 0.         0.25       0.05263158 0.         0.        ]
Time taken for epoch 26: 3.7076921661694846 mins
----------------------------------------------

  _warn_prf(average, modifier, msg_start, len(result))


TRAINING
Loss 5.799484047863292
Accuracy 0.19122953094643352
Precision [0.         0.         0.20742534 0.         0.17253521 0.03508772
 0.0625     0.10852713 0.0625     0.         0.34615385]


  _warn_prf(average, modifier, msg_start, len(result))


VALIDATION
Loss 4.985834568985237
Accuracy 0.16306156405990016
Precision [0.         0.         0.22222222 0.         0.16211293 0.11111111
 0.125      0.15384615 0.         0.         0.        ]
Time taken for epoch 29: 3.7202760140101114 mins
-----------------------------------------------------------------------


  _warn_prf(average, modifier, msg_start, len(result))


In [20]:
print(f"train_losses: {train_losses}")
print(f"val_losses: {val_losses}")
print(f"train_accuracies: {train_accuracies}")
print(f"val_accuracies: {val_accuracies}")

train_losses: [3.85413773766515, 3.110879807921328, 3.041332032541819, 3.111503958041648, 3.1925109129533213, 3.4739198437027654, 3.4290243793392445, 3.517395221625669, 3.8616642281619464, 3.6800238954063267, 3.8776755250391868, 4.084185120471627, 4.15687751373756, 4.168100578963261, 4.3738785866555085, 4.427083623376249, 4.4307004768102125, 4.358303716308193, 4.710883401767699, 4.843249582187621, 4.966830331202689, 5.314562183007639, 5.5315894581934755, 4.833652178666599, 4.9689054482531345, 5.672179143844879, 5.411199210423182, 5.585688322204632, 6.056825450582847, 5.799484047863292]
val_losses: [3.0663892405092223, 2.8594222659907067, 3.1467551997870453, 3.011469419337501, 3.4596562060442837, 3.610465953172731, 3.4818571845362007, 3.7959799244384134, 3.4954152048126725, 4.124250153864711, 3.495139532837986, 3.7071762291853094, 3.832592469601592, 4.443883483074913, 4.9685623754154555, 4.009955929330558, 4.496937647338741, 4.365154358966292, 5.386953786384961, 4.636099676455348, 5.978

Save model

In [21]:
model_path = "/home/jt9744/COS429/429_Final/herve_models_trained/" + save_name 
torch.save(i3d, model_path)

Check saved output

In [22]:
torch.load(model_path)

InceptionI3d(
  (avg_pool): AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1), padding=0)
  (dropout): Dropout(p=0.5, inplace=False)
  (logits): Unit3D(
    (conv3d): Conv3d(1024, 11, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  )
  (Conv3d_1a_7x7): Unit3D(
    (conv3d): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), bias=False)
    (bn): BatchNorm3d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
  )
  (MaxPool3d_2a_3x3): MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (Conv3d_2b_1x1): Unit3D(
    (conv3d): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (bn): BatchNorm3d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
  )
  (Conv3d_2c_3x3): Unit3D(
    (conv3d): Conv3d(64, 192, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False)
    (bn): BatchNorm3d(192, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
  )
  (MaxPool3d_3a_3x3): MaxPool3dSam