In [None]:
import os
import cv2
import sys
import time
import torch
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Put the parent directory on the import path so that the project-local packages
# imported in the following cells (data_process, utils, config) can be resolved.
# NOTE(review): assumes the notebook is started from a folder one level below the
# project root - confirm.
sys.path.append('../')

In [None]:
from tqdm import tqdm
from turbojpeg import TurboJPEG
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from data_process.ttnet_data_utils import load_raw_img

In [None]:
##  Dataset
from utils.misc import *
from utils.logger import Logger
from config.config import parse_configs
from utils.train_utils import create_optimizer, create_lr_scheduler,  reduce_tensor, to_python_float, get_saved_state, save_checkpoint
from data_process.ttnet_data_utils import train_val_data_separation, get_events_infor

### Phases

In [None]:
# Start from the project's default configuration and override it for this notebook run.
configs = parse_configs()
configs.distributed = False
# configs.multitask_learning = True

# Phase 1
# NOTE: the original cell assigned lr_factor (0.5, then 0.1) and global_weight (5, then 5.)
# twice; only the last assignment ever took effect, so each option is now set exactly once
# with the value that was actually in force.
configs.smooth_labelling = True
configs.saved_fn = 'ttnet_1st_phase'
configs.no_val = True
configs.lr = 0.001
configs.lr_type = 'step_lr'
configs.lr_step_size = 10
configs.lr_factor = 0.1
configs.gpu_idx = 0
configs.global_weight = 5.
# no_event / no_local presumably switch off the event-spotting and local ball stages
# for this phase - confirm against config.config.
configs.no_event = True
configs.no_local = True
configs.print_freq = 500
configs.batch_size = 24
# configs.sigma = 1.0

# Phase 2 (template: uncomment to use; written as valid Python attribute assignments)
# configs.saved_fn = 'ttnet_2nd_phase'
# configs.no_val = True
# configs.lr = 0.001
# configs.lr_type = 'step_lr'
# configs.lr_step_size = 10
# configs.lr_factor = 0.1
# configs.gpu_idx = 0
# configs.global_weight = 0.
# configs.event_weight = 2.
# configs.local_weight = 1.
# configs.pretrained_path = '../checkpoints/ttnet_1st_phase/ttnet_1st_phase_epoch_30.pth'
# configs.overwrite_global_2_local = True
# configs.freeze_global = True
# configs.smooth_labelling = True


In [None]:
# Build the run logger from the configured log directory and run name (configs.saved_fn),
# then record the full configuration so the run can be traced back to its settings.
logger = Logger(configs.logs_dir, configs.saved_fn)
logger.info('>>> Created a new logger')
logger.info('>>> configs: {}'.format(configs))

#### Dataloader Exploration

In [None]:
## Dataset Class
class TTNet_Dataset(Dataset):
    """Dataset yielding a channel-stacked sequence of frames with ball-position and event targets.

    Every entry of ``events_infor`` is unpacked as
    ``(img_path_list, org_ball_pos_xy, target_events)``.
    """

    def __init__(self, events_infor, org_size, input_size, transform=None, num_samples=None):
        """
        :param events_infor: indexable collection of (img_path_list, org_ball_pos_xy, target_events)
        :param org_size: (width, height) of the original frames
        :param input_size: (width, height) the frames are resized to
        :param transform: optional callable ``(imgs, ball_pos_xy) -> (imgs, ball_pos_xy)``
        :param num_samples: if not None, only the first ``num_samples`` entries are kept
        """
        self.events_infor = events_infor
        self.w_org = org_size[0]
        self.h_org = org_size[1]
        self.w_input = input_size[0]
        self.h_input = input_size[1]
        self.w_resize_ratio = self.w_org / self.w_input
        self.h_resize_ratio = self.h_org / self.h_input
        self.transform = transform
        if num_samples is not None:
            self.events_infor = self.events_infor[:num_samples]
        # The JPEG decoder is created on first use and then reused for every sample,
        # instead of being rebuilt on each __getitem__ call.
        self.jpeg_reader = None

    def __len__(self):
        """Number of available samples."""
        return len(self.events_infor)

    def __resize_ball_pos__(self, ball_pos_xy, w_ratio, h_ratio):
        """Return a new array with (x, y) divided by (w_ratio, h_ratio)."""
        return np.array([ball_pos_xy[0] / w_ratio, ball_pos_xy[1] / h_ratio])

    def __check_ball_pos__(self, ball_pos_xy, w, h):
        """Overwrite ``ball_pos_xy`` in place with (-1, -1) unless 0 < x < w and 0 < y < h."""
        if not ((0 < ball_pos_xy[0] < w) and (0 < ball_pos_xy[1] < h)):
            ball_pos_xy[0] = -1.
            ball_pos_xy[1] = -1.

    def __getitem__(self, index):
        """Load, resize and stack the frames of sample ``index``.

        :return: (resized_imgs in (C, H, W) layout, ball position in original-size coordinates,
                  ball position in input-size coordinates, target_events); both positions are
                  integer arrays and are (-1, -1) when the ball lies outside the image
        """
        img_path_list, org_ball_pos_xy, target_events = self.events_infor[index]
        if self.jpeg_reader is None:
            self.jpeg_reader = TurboJPEG()
        # Load the sequence of images and resize each one before stacking them together.
        # TurboJPEG is used to speed up the image-loading phase.
        resized_imgs = []
        for img_path in img_path_list:
            # The context manager closes the file even if decoding fails.
            with open(img_path, 'rb') as in_file:
                # Second argument is the pixel format (0 is presumably TJPF_RGB - confirm
                # against the installed PyTurboJPEG version).
                decoded_img = self.jpeg_reader.decode(in_file.read(), 0)
            resized_imgs.append(cv2.resize(decoded_img, (self.w_input, self.h_input)))
        # Stack along the channel axis, e.g. (128, 320, 27) for nine 320x128 RGB frames
        resized_imgs = np.dstack(resized_imgs)
        # Adjust ball pos: original size --> input size
        global_ball_pos_xy = self.__resize_ball_pos__(org_ball_pos_xy, self.w_resize_ratio, self.h_resize_ratio)

        # Apply augmentation
        if self.transform:
            resized_imgs, global_ball_pos_xy = self.transform(resized_imgs, global_ball_pos_xy)
        # Adjust ball pos: input size --> original size (picks up any shift made by the augmentation)
        org_ball_pos_xy = self.__resize_ball_pos__(global_ball_pos_xy, 1. / self.w_resize_ratio,
                                                   1. / self.h_resize_ratio)
        # If the ball position is outside of the image, set position to -1, -1 --> No ball (just for safety)
        self.__check_ball_pos__(org_ball_pos_xy, self.w_org, self.h_org)
        self.__check_ball_pos__(global_ball_pos_xy, self.w_input, self.h_input)

        # Transpose (H, W, C) to (C, H, W) --> fit input of Pytorch model
        resized_imgs = resized_imgs.transpose(2, 0, 1)

        # np.int was only an alias of the builtin int and no longer exists in NumPy >= 1.24
        return resized_imgs, org_ball_pos_xy.astype(int), global_ball_pos_xy.astype(int), target_events

In [None]:
## Transformations 
class Compose(object):
    """Chain several ``(imgs, ball_position_xy)`` transforms into a single callable.

    With probability ``p`` every transform is applied in list order; otherwise the
    inputs are handed back untouched.
    """

    def __init__(self, transforms, p=1.0):
        self.transforms = transforms
        self.p = p

    def __call__(self, imgs, ball_position_xy):
        # One random draw decides whether the whole pipeline runs.
        pipeline_enabled = random.random() <= self.p
        if not pipeline_enabled:
            return imgs, ball_position_xy

        for transform in self.transforms:
            imgs, ball_position_xy = transform(imgs, ball_position_xy)
        return imgs, ball_position_xy
class Normalize():
    """Scale a channel-stacked frame sequence to [0, 1] and standardise it per colour channel.

    :param mean: per-colour mean (3 values)
    :param std: per-colour standard deviation (3 values)
    :param num_frames_sequence: number of 3-channel frames stacked along the last axis
    :param p: probability of applying the normalisation
    """
    # Sentinel telling "seg_img was not passed" apart from an explicit ``seg_img=None``.
    _NO_SEG = object()

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), num_frames_sequence=9, p=1.0):
        self.p = p
        # Frames are stacked frame by frame (c0, c1, c2, c0, c1, c2, ...), so the statistics
        # must be tiled as whole triplets. np.repeat would produce (c0 x N, c1 x N, c2 x N)
        # and pair most channels with the wrong mean/std.
        self.mean = np.tile(np.array(mean).reshape(1, 1, 3), (1, 1, num_frames_sequence))
        self.std = np.tile(np.array(std).reshape(1, 1, 3), (1, 1, num_frames_sequence))

    def __call__(self, imgs, ball_position_xy, seg_img=_NO_SEG):
        """
        :param imgs: (H, W, 3 * num_frames_sequence) array with values in [0, 255]
        :param ball_position_xy: passed through unchanged
        :param seg_img: optional extra item, passed through unchanged
        :return: ``(imgs, ball_position_xy, seg_img)`` when ``seg_img`` is given (the original
                 behaviour); ``(imgs, ball_position_xy)`` otherwise, which is the two-argument
                 protocol Compose uses for its transforms
        """
        if random.random() < self.p:
            imgs = ((imgs / 255.) - self.mean) / self.std

        if seg_img is Normalize._NO_SEG:
            return imgs, ball_position_xy
        return imgs, ball_position_xy, seg_img
class Denormalize():
    """Undo a per-channel standardisation on one 3-channel image and return it as uint8."""

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1.0):
        self.p = p  # stored only; __call__ never reads it
        # Shape (1, 1, 3) so the statistics broadcast over an (H, W, 3) image.
        broadcast_shape = (1, 1, 3)
        self.mean = np.array(mean).reshape(broadcast_shape)
        self.std = np.array(std).reshape(broadcast_shape)

    def __call__(self, img):
        restored = (img * self.std + self.mean) * 255.
        # astype truncates towards zero; values are not clipped beforehand.
        return restored.astype(np.uint8)
class Resize(object):
    """Randomly resize a channel-stacked frame sequence and rescale the ball position with it.

    :param new_size: target size, passed to cv2.resize as given and indexed as
                     ``new_size[0]`` (width) / ``new_size[1]`` (height)
    :param p: probability of applying the resize
    :param interpolation: OpenCV interpolation flag
    """

    def __init__(self, new_size, p=0.5, interpolation=cv2.INTER_LINEAR):
        self.new_size = new_size
        self.p = p
        self.interpolation = interpolation

    def __call__(self, imgs, ball_position_xy):
        if not (random.random() <= self.p):
            return imgs, ball_position_xy

        src_h, src_w, _ = imgs.shape
        # Resize the whole sequence in a single call
        imgs = cv2.resize(imgs, self.new_size, interpolation=self.interpolation)
        # The ball position shrinks/grows by the same factors as the image
        shrink_x = src_w / self.new_size[0]
        shrink_y = src_h / self.new_size[1]
        ball_x = ball_position_xy[0] / shrink_x
        ball_y = ball_position_xy[1] / shrink_y
        return imgs, np.array([ball_x, ball_y])
class Random_Crop(object):
    """Randomly crop a window that keeps the aspect ratio, then scale it back to the input size.

    :param max_reduction_percent: largest fraction of width/height that may be cropped away
    :param p: probability of applying the crop
    :param interpolation: OpenCV interpolation flag used when scaling the crop back up
    """

    def __init__(self, max_reduction_percent=0.15, p=0.5, interpolation=cv2.INTER_LINEAR):
        self.max_reduction_percent = max_reduction_percent
        self.p = p
        self.interpolation = interpolation

    def __call__(self, imgs, ball_position_xy):
        # imgs are before resizing
        if not (random.random() <= self.p):
            return imgs, ball_position_xy

        full_h, full_w, _ = imgs.shape
        # One shared factor for both axes keeps the aspect ratio of the window
        keep_fraction = random.uniform(1. - self.max_reduction_percent, 1.)

        # Horizontal extent of the window
        crop_w = keep_fraction * full_w
        left = int(random.uniform(0, full_w - crop_w))
        right = int(left + crop_w)
        scale_x = full_w / crop_w

        # Vertical extent of the window
        crop_h = keep_fraction * full_h
        top = int(random.uniform(0, full_h - crop_h))
        bottom = int(crop_h + top)
        scale_y = full_h / crop_h

        # Cut the window out of every frame at once and scale it back to (full_w, full_h)
        window = imgs[top:bottom, left:right, :]
        imgs = cv2.resize(window, (full_w, full_h), interpolation=self.interpolation)

        # Shift the ball into window coordinates, then apply the same up-scaling
        shifted_x = ball_position_xy[0] - left
        shifted_y = ball_position_xy[1] - top
        return imgs, np.array([shifted_x * scale_x, shifted_y * scale_y])
class Random_Rotate(object):
    """Randomly rotate a channel-stacked frame sequence about its centre.

    :param rotation_angle_limit: the angle (degrees) is drawn uniformly from [-limit, +limit]
    :param p: probability of applying the rotation
    """

    def __init__(self, rotation_angle_limit=15, p=0.5):
        self.rotation_angle_limit = rotation_angle_limit
        self.p = p

    def __call__(self, imgs, ball_position_xy):
        if not (random.random() <= self.p):
            return imgs, ball_position_xy

        limit = self.rotation_angle_limit
        angle = random.uniform(-limit, limit)

        # Build one affine matrix and warp the whole sequence with it
        height, width, _ = imgs.shape
        pivot = (int(width / 2), int(height / 2))
        affine = cv2.getRotationMatrix2D(pivot, angle, 1.)
        imgs = cv2.warpAffine(imgs, affine, (width, height), flags=cv2.INTER_LINEAR)

        # The same 2x3 matrix moves the ball: multiply it with the homogeneous point (x, y, 1)
        homogeneous_ball = np.array([ball_position_xy[0], ball_position_xy[1], 1.])
        ball_position_xy = affine.dot(homogeneous_ball)
        return imgs, ball_position_xy
class Random_HFlip(object):
    """Randomly mirror a channel-stacked frame sequence left-to-right.

    :param p: probability of applying the flip
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, imgs, ball_position_xy):
        """
        :param imgs: (H, W, C) array holding the stacked frames
        :param ball_position_xy: (x, y) of the ball; it is never modified in place
        :return: (imgs, ball_position_xy); when the flip is applied both are new arrays
        """
        if random.random() <= self.p:
            h, w, c = imgs.shape
            # Horizontal flip of every frame: reverse the width axis. ascontiguousarray yields
            # a fresh C-ordered array, i.e. the same result cv2.flip(imgs, 1) returns.
            imgs = np.ascontiguousarray(imgs[:, ::-1, :])
            # Adjust ball position: same y, new x = w - x. A new array is built (as the other
            # transforms do) so the caller's position object is not silently overwritten,
            # and tuple/list positions are accepted as well.
            ball_position_xy = np.array([w - ball_position_xy[0], ball_position_xy[1]])

        return imgs, ball_position_xy

In [None]:
#Dataloader
def create_train_val_dataloader(configs):
    """Create dataloader for training and validate

    :param configs: run configuration; read here for org_size, input_size, num_samples,
                    distributed, batch_size, pin_memory, num_workers and no_val
    :return: (train_dataloader, val_dataloader, train_sampler); val_dataloader is None when
             configs.no_val is set, train_sampler is None unless configs.distributed is set
    """
    augmentations = [
        Random_Crop(max_reduction_percent=0.15, p=0.5),
        Random_HFlip(p=0.5),
        Random_Rotate(rotation_angle_limit=10, p=0.5),
    ]
    train_transform = Compose(augmentations, p=1.)

    train_events_infor, val_events_infor, *_ = train_val_data_separation(configs)

    # ---- training split ----
    train_dataset = TTNet_Dataset(train_events_infor, configs.org_size, configs.input_size,
                                  transform=train_transform, num_samples=configs.num_samples)
    if configs.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    # DataLoader may only shuffle by itself when no sampler is supplied
    train_dataloader = DataLoader(train_dataset, batch_size=configs.batch_size,
                                  shuffle=(train_sampler is None), pin_memory=configs.pin_memory,
                                  num_workers=configs.num_workers, sampler=train_sampler)

    if configs.no_val:
        return train_dataloader, None, train_sampler

    # ---- validation split: no augmentation, fixed order ----
    val_dataset = TTNet_Dataset(val_events_infor, configs.org_size, configs.input_size,
                                transform=None, num_samples=configs.num_samples)
    if configs.distributed:
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
    else:
        val_sampler = None
    val_dataloader = DataLoader(val_dataset, batch_size=configs.batch_size, shuffle=False,
                                pin_memory=configs.pin_memory, num_workers=configs.num_workers,
                                sampler=val_sampler)

    return train_dataloader, val_dataloader, train_sampler
def create_test_dataloader(configs):
    """Create dataloader for testing phase

    :param configs: run configuration; read here for test_game_list, org_size, input_size,
                    num_samples, distributed, batch_size, pin_memory and num_workers
    :return: DataLoader over the test samples, without augmentation and in a fixed order
    """

    test_transform = None
    dataset_type = 'test'
    # Only the per-sample information is needed here; the second return value is ignored.
    test_events_infor, _ = get_events_infor(configs.test_game_list, configs, dataset_type)
    test_dataset = TTNet_Dataset(test_events_infor, configs.org_size, configs.input_size, transform=test_transform,
                                 num_samples=configs.num_samples)
    test_sampler = None
    if configs.distributed:
        # DistributedSampler shuffles by default; evaluation should see the samples in a
        # deterministic order, exactly like the validation sampler in create_train_val_dataloader.
        test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=configs.batch_size, shuffle=False,
                                 pin_memory=configs.pin_memory, num_workers=configs.num_workers, sampler=test_sampler)

    return test_dataloader

In [None]:
def check_dataset():
    """Visual smoke test for TTNet_Dataset and the augmentation pipeline.

    Loads one training sample with every augmentation forced on (p=1), upsamples the
    resized frames back to the original resolution and writes three JPEGs into
    ``<configs.results_dir>/debug/ttnet_dataset``: a grid of the upsampled frames, the
    last augmented frame with the ball position circled, and a grid of the augmented frames.
    """
    configs = parse_configs()
    train_events_infor, val_events_infor, *_ = train_val_data_separation(configs)
    print('len(train_events_infor): {}'.format(len(train_events_infor)))
    # Test transformation
    transform = Compose([
        Random_Crop(max_reduction_percent=0.15, p=1.),
        Random_HFlip(p=1.),
        Random_Rotate(rotation_angle_limit=15, p=1.)
    ], p=1.)
    ttnet_dataset = TTNet_Dataset(train_events_infor, configs.org_size, configs.input_size, transform=transform)
    print('len(ttnet_dataset): {}'.format(len(ttnet_dataset)))
    # Sample 100 is inspected when it exists; smaller datasets fall back to their last sample
    example_index = min(100, len(ttnet_dataset) - 1)
    resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_event = ttnet_dataset[example_index]

    # configs.org_size is (width, height), the convention TTNet_Dataset uses for org_size
    org_w, org_h = configs.org_size[0], configs.org_size[1]
    use_torch_interpolate = True
    if use_torch_interpolate:
        # Test F.interpolate, we can simply use cv2.resize() to get origin_imgs from resized_imgs
        # Achieve better quality of images and faster
        origin_imgs = F.interpolate(torch.from_numpy(resized_imgs).unsqueeze(0).float(), (org_h, org_w))
        origin_imgs = origin_imgs.squeeze(0).numpy().transpose(1, 2, 0).astype(np.uint8)
        print('F.interpolate - origin_imgs shape: {}'.format(origin_imgs.shape))
        resized_imgs = resized_imgs.transpose(1, 2, 0)
        print('resized_imgs shape: {}'.format(resized_imgs.shape))
    else:
        # Test cv2.resize
        resized_imgs = resized_imgs.transpose(1, 2, 0)
        print('resized_imgs shape: {}'.format(resized_imgs.shape))
        origin_imgs = cv2.resize(resized_imgs, (org_w, org_h))
        print('cv2.resize - origin_imgs shape: {}'.format(origin_imgs.shape))

    out_images_dir = os.path.join(configs.results_dir, 'debug', 'ttnet_dataset')
    # exist_ok avoids the check-then-create race of isdir() followed by makedirs()
    os.makedirs(out_images_dir, exist_ok=True)

    # NOTE: the 3 x 3 grid offers nine axes, so configs.num_frames_sequence must not exceed 9
    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(20, 20))
    axes = axes.ravel()

    # Grid 1: frames upsampled back to the original resolution
    for i in range(configs.num_frames_sequence):
        img = origin_imgs[:, :, (i * 3): (i + 1) * 3]
        axes[i].imshow(img)
        axes[i].set_title('image {}'.format(i))
    fig.suptitle(
        'Event: is bounce {}, is net: {}, ball_position_xy: (x= {}, y= {})'.format(target_event[0], target_event[1],
                                                                                   org_ball_pos_xy[0],
                                                                                   org_ball_pos_xy[1]),
        fontsize=16)
    plt.savefig(os.path.join(out_images_dir, 'org_all_imgs_{}.jpg'.format(example_index)))

    # Grid 2: augmented frames at network input size, drawn over the same axes
    for i in range(configs.num_frames_sequence):
        img = resized_imgs[:, :, (i * 3): (i + 1) * 3]
        if (i == (configs.num_frames_sequence - 1)):
            # cv2.circle draws in place and needs a contiguous buffer, so work on a copy of the slice
            img = img.copy()
            # Plain Python ints: some OpenCV builds reject NumPy integer scalars as the centre
            ball_center = (int(global_ball_pos_xy[0]), int(global_ball_pos_xy[1]))
            ball_img = cv2.circle(img, ball_center, radius=5, color=(255, 0, 0), thickness=2)
            ball_img = cv2.cvtColor(ball_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(out_images_dir, 'augment_img_{}.jpg'.format(example_index)),
                        ball_img)

        axes[i].imshow(img)
        axes[i].set_title('image {}'.format(i))
    fig.suptitle(
        'Event: is bounce {}, is net: {}, ball_position_xy: (x= {}, y= {})'.format(target_event[0], target_event[1],
                                                                                   global_ball_pos_xy[0],
                                                                                   global_ball_pos_xy[1]),
        fontsize=16)
    plt.savefig(os.path.join(out_images_dir, 'augment_all_imgs_{}.jpg'.format(example_index)))

In [None]:
# def check_dataloader():
#     configs = parse_configs()
#     configs.distributed = False  # For testing
#     train_dataloader, val_dataloader, train_sampler = create_train_val_dataloader(configs)
#     print('len train_dataloader: {}, val_dataloader: {}'.format(len(train_dataloader), len(val_dataloader)))
#     return train_dataloader

In [None]:
# train_dataloader = check_dataloader()
# for data in train_dataloader:
#     print(len(data))
#     break
# data[0] : stack of images ( 9 frames per batch )
# data[1] : original coordinates of ball
# data[2] : scaled coordinates of ball
# data[3] : event labels

In [None]:
# for i in range(4):
#     print(data[i].shape)

#### Model Exploration

#####  TTNet

<img src="artifacts/network.png" width="1000"/>

In [None]:
class ConvBlock(nn.Module):
    """3x3 convolution -> batch norm -> ReLU -> 2x2 max pooling (halves height and width)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size; only the pooling shrinks it
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        for layer in (self.conv, self.batchnorm, self.relu, self.maxpool):
            x = layer(x)
        return x
class ConvBlock_without_Pooling(nn.Module):
    """3x3 convolution -> batch norm -> ReLU; the spatial size is left unchanged."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        convolved = self.conv(x)
        normalised = self.batchnorm(convolved)
        return self.relu(normalised)
class BallDetection(nn.Module):
    """Ball-position head: six pooled conv blocks followed by three fully connected layers.

    :param num_frames_sequence: number of RGB frames stacked along the channel axis
    :param dropout_p: drop probability shared by the 2D and 1D dropout layers
    """

    def __init__(self, num_frames_sequence, dropout_p):
        super(BallDetection, self).__init__()
        self.conv1 = nn.Conv2d(num_frames_sequence * 3, 64, kernel_size=1, stride=1, padding=0)
        self.batchnorm = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.convblock1 = ConvBlock(in_channels=64, out_channels=64)
        self.convblock2 = ConvBlock(in_channels=64, out_channels=64)
        self.dropout2d = nn.Dropout2d(p=dropout_p)
        self.convblock3 = ConvBlock(in_channels=64, out_channels=128)
        self.convblock4 = ConvBlock(in_channels=128, out_channels=128)
        self.convblock5 = ConvBlock(in_channels=128, out_channels=256)
        self.convblock6 = ConvBlock(in_channels=256, out_channels=256)
        # 2560 input features = 256 channels x 2 x 5, i.e. a 128 x 320 input halved six times
        self.fc1 = nn.Linear(in_features=2560, out_features=1792)
        self.fc2 = nn.Linear(in_features=1792, out_features=896)
        # 448 outputs; callers in this file split them as 320 x-bins followed by 128 y-bins
        self.fc3 = nn.Linear(in_features=896, out_features=448)
        self.dropout1d = nn.Dropout(p=dropout_p)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        :param x: (batch_size, num_frames_sequence * 3, 128, 320) normalised frames
        :return: (out, features) - out is (batch_size, 448) with values in (0, 1);
                 features is the output of the last conv block, taken before dropout
        """
        x = self.relu(self.batchnorm(self.conv1(x)))
        out_block2 = self.convblock2(self.convblock1(x))
        x = self.dropout2d(out_block2)
        out_block3 = self.convblock3(x)
        out_block4 = self.convblock4(out_block3)
        x = self.dropout2d(out_block4)
        # Feed the dropped-out tensor forward. The original passed out_block4 here, so the
        # dropout computed on the line above was thrown away and never regularised block 5.
        out_block5 = self.convblock5(x)
        features = self.convblock6(out_block5)

        x = self.dropout2d(features)
        x = x.contiguous().view(x.size(0), -1)

        x = self.dropout1d(self.relu(self.fc1(x)))
        x = self.dropout1d(self.relu(self.fc2(x)))
        out = self.sigmoid(self.fc3(x))

        return out, features
class EventsSpotting(nn.Module):
    """Event head fed with the concatenated global and local ball-stage features.

    :param dropout_p: drop probability of the 2D dropout applied after each conv stage
    """

    def __init__(self, dropout_p):
        super().__init__()
        # 512 input channels = global features + local features concatenated on dim 1
        self.conv1 = nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0)
        self.batchnorm = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.dropout2d = nn.Dropout2d(p=dropout_p)
        self.convblock = ConvBlock_without_Pooling(in_channels=64, out_channels=64)
        self.fc1 = nn.Linear(in_features=640, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, global_features, local_features):
        """Return a (batch_size, 2) tensor of sigmoid outputs in (0, 1)."""
        fused = torch.cat((global_features, local_features), dim=1)
        x = self.dropout2d(self.relu(self.batchnorm(self.conv1(fused))))

        # The single conv block (shared weights) is applied twice, each pass followed by dropout
        for _ in range(2):
            x = self.dropout2d(self.convblock(x))

        flattened = x.contiguous().view(x.size(0), -1)
        hidden = self.relu(self.fc1(flattened))
        return self.sigmoid(self.fc2(hidden))
class TTNet(nn.Module):
    """Global ball detection with optional local refinement and events spotting.

    The global stage always runs on the resized frames. When 'local' is in ``tasks`` a second
    BallDetection runs on a crop of the up-scaled frames centred on the global prediction.
    When 'event' is in ``tasks`` an EventsSpotting head is built; it is only evaluated when the
    local stage exists, because it consumes the features of both stages.

    :param dropout_p: drop probability handed to every sub-module
    :param tasks: container of task names; 'local' and 'event' switch on the optional stages
    :param input_size: (width, height) of the resized network input
    :param thresh_ball_pos_mask: global predictions below this value are treated as "no ball"
    :param num_frames_sequence: number of RGB frames stacked along the channel axis
    :param mean: per-colour mean used for input normalisation
    :param std: per-colour standard deviation used for input normalisation
    """

    def __init__(self, dropout_p, tasks, input_size, thresh_ball_pos_mask, num_frames_sequence,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super(TTNet, self).__init__()
        self.tasks = tasks
        self.ball_local_stage, self.events_spotting = None, None
        self.ball_global_stage = BallDetection(num_frames_sequence=num_frames_sequence, dropout_p=dropout_p)
        if 'local' in tasks:
            self.ball_local_stage = BallDetection(num_frames_sequence=num_frames_sequence, dropout_p=dropout_p)
        if 'event' in tasks:
            self.events_spotting = EventsSpotting(dropout_p=dropout_p)
        self.w_resize = input_size[0]
        self.h_resize = input_size[1]
        self.thresh_ball_pos_mask = thresh_ball_pos_mask
        # Frames are stacked frame by frame (c0, c1, c2, c0, c1, c2, ...), so the statistics
        # are tiled as whole triplets, once per frame. repeat_interleave would give
        # (c0 x N, c1 x N, c2 x N) and normalise most channels with the wrong colour's
        # statistics; the frame count was also hard-coded to 9.
        # Kept as plain tensors (not buffers) so the state_dict layout is unchanged.
        self.mean = torch.tensor(mean).view(1, 3, 1, 1).repeat(1, num_frames_sequence, 1, 1)
        self.std = torch.tensor(std).view(1, 3, 1, 1).repeat(1, num_frames_sequence, 1, 1)

    def forward(self, resize_batch_input, org_ball_pos_xy):
        """Forward propagation
        :param resize_batch_input: (batch_size, 27, 128, 320)
        :param org_ball_pos_xy: (batch_size, 2) --> Use it to get ground-truth for the local stage
        :return: (pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy); the last
                 three are None when the corresponding stage is disabled
        """
        pred_ball_local, pred_events, local_ball_pos_xy = None, None, None

        # Normalize the input before compute forward propagation
        pred_ball_global, global_features = self.ball_global_stage(
            self.__normalize__(resize_batch_input))
        if self.ball_local_stage is not None:
            # Based on the prediction of the global stage, crop the original images
            input_ball_local, cropped_params = self.__crop_original_batch__(resize_batch_input, pred_ball_global)
            # Get the ground truth of the ball for the local stage
            local_ball_pos_xy = self.__get_groundtruth_local_ball_pos__(org_ball_pos_xy, cropped_params)
            # Normalize the input before compute forward propagation
            pred_ball_local, local_features = self.ball_local_stage(self.__normalize__(input_ball_local))
            # Only consider the events spotting if the model has the local stage for ball detection
            if self.events_spotting is not None:
                pred_events = self.events_spotting(global_features, local_features)

        return pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy

    def run_demo(self, resize_batch_input):
        """Inference-only pass; requires both the local stage and the events-spotting head.

        :param resize_batch_input: (batch_size, 27, 128, 320)
        :return: (pred_ball_global, pred_ball_local, pred_events, pred_seg). pred_seg is always
                 None: this model has no segmentation module and BallDetection only returns
                 (out, features). The slot is kept so the four-value return shape is preserved.
        """
        # Normalize the input before compute forward propagation.
        # ``*_`` tolerates ball stages that return extra intermediate feature maps.
        pred_ball_global, global_features, *_ = self.ball_global_stage(
            self.__normalize__(resize_batch_input))
        input_ball_local, cropped_params = self.__crop_original_batch__(resize_batch_input, pred_ball_global)
        # Normalize the input before compute forward propagation
        pred_ball_local, local_features, *_ = self.ball_local_stage(self.__normalize__(input_ball_local))
        pred_events = self.events_spotting(global_features, local_features)
        pred_seg = None

        return pred_ball_global, pred_ball_local, pred_events, pred_seg

    def __normalize__(self, x):
        """Scale ``x`` from [0, 255] to [0, 1] and standardise it per channel."""
        # Follow the input's device rather than forcing CUDA, so the model also runs on CPU.
        if self.mean.device != x.device:
            self.mean = self.mean.to(x.device)
            self.std = self.std.to(x.device)

        return (x / 255. - self.mean) / self.std

    def __get_groundtruth_local_ball_pos__(self, org_ball_pos_xy, cropped_params):
        """Express the ground-truth ball position in the coordinates of each local crop.

        :param org_ball_pos_xy: (batch_size, 2) ball position in original-size coordinates
        :param cropped_params: per-sample list produced by __crop_original_batch__
        :return: (batch_size, 2) tensor; rows are (-1, -1) when the global stage detected no
                 ball or when the ball falls outside the crop
        """
        local_ball_pos_xy = torch.zeros_like(org_ball_pos_xy)  # no grad for torch.zeros_like output

        for idx, params in enumerate(cropped_params):
            is_ball_detected, x_min, x_max, y_min, y_max, x_pad, y_pad = params

            if is_ball_detected:
                # Get the local ball position based on the crop image information
                local_ball_pos_xy[idx, 0] = max(org_ball_pos_xy[idx, 0] - x_min + x_pad, -1)
                local_ball_pos_xy[idx, 1] = max(org_ball_pos_xy[idx, 1] - y_min + y_pad, -1)
                # If the ball is outside of the cropped image --> set position to -1, -1 --> No ball
                if (local_ball_pos_xy[idx, 0] >= self.w_resize) or (local_ball_pos_xy[idx, 1] >= self.h_resize) or (
                        local_ball_pos_xy[idx, 0] < 0) or (local_ball_pos_xy[idx, 1] < 0):
                    local_ball_pos_xy[idx, 0] = -1
                    local_ball_pos_xy[idx, 1] = -1
            else:
                local_ball_pos_xy[idx, 0] = -1
                local_ball_pos_xy[idx, 1] = -1
        return local_ball_pos_xy

    def __crop_original_batch__(self, resize_batch_input, pred_ball_global):
        """Get input of the local stage by cropping the original images based on the predicted ball position
            of the global stage
        :param resize_batch_input: (batch_size, 27, 128, 320)
        :param pred_ball_global: (batch_size, 448)
        :return: input_ball_local (batch_size, 27, 128, 320) and cropped_params, one
                 [is_ball_detected, x_min, x_max, y_min, y_max, x_pad, y_pad] list per sample
        """
        # Process input for local stage based on output of the global one

        batch_size = resize_batch_input.size(0)
        # The original resolution is fixed to full HD in this method
        h_original, w_original = 1080, 1920
        h_ratio = h_original / self.h_resize
        w_ratio = w_original / self.w_resize
        pred_ball_global_mask = pred_ball_global.clone().detach()
        pred_ball_global_mask[pred_ball_global_mask < self.thresh_ball_pos_mask] = 0.

        # Crop the original images
        input_ball_local = torch.zeros_like(resize_batch_input)  # same shape with resize_batch_input, no grad
        # Up-scale the resized batch back to the original resolution (same device as the input)
        original_batch_input = F.interpolate(resize_batch_input, (h_original, w_original))
        cropped_params = []
        for idx in range(batch_size):
            # The first w_resize outputs describe x, the remaining ones describe y
            pred_ball_pos_x = pred_ball_global_mask[idx, :self.w_resize]
            pred_ball_pos_y = pred_ball_global_mask[idx, self.w_resize:]
            # If the ball is not detected, we crop the center of the images, set ball_poss to [-1, -1]
            if (torch.sum(pred_ball_pos_x) == 0.) or (torch.sum(pred_ball_pos_y) == 0.):
                # Assume the ball is in the center image
                x_center = int(self.w_resize / 2)
                y_center = int(self.h_resize / 2)
                is_ball_detected = False
            else:
                x_center = torch.argmax(pred_ball_pos_x)  # Upper part
                y_center = torch.argmax(pred_ball_pos_y)  # Lower part
                is_ball_detected = True

            # Adjust ball position to the original size
            x_center = int(x_center * w_ratio)
            y_center = int(y_center * h_ratio)

            x_min, x_max, y_min, y_max = self.__get_crop_params__(x_center, y_center, self.w_resize, self.h_resize,
                                                                  w_original, h_original)
            # Put image to the center
            h_crop = y_max - y_min
            w_crop = x_max - x_min
            x_pad = 0
            y_pad = 0
            if (h_crop != self.h_resize) or (w_crop != self.w_resize):
                # The crop was clipped at an image border: centre it inside the zero-filled canvas
                x_pad = int((self.w_resize - w_crop) / 2)
                y_pad = int((self.h_resize - h_crop) / 2)
                input_ball_local[idx, :, y_pad:(y_pad + h_crop), x_pad:(x_pad + w_crop)] = original_batch_input[idx, :,
                                                                                           y_min:y_max, x_min: x_max]
            else:
                input_ball_local[idx, :, :, :] = original_batch_input[idx, :, y_min:y_max, x_min: x_max]
            cropped_params.append([is_ball_detected, x_min, x_max, y_min, y_max, x_pad, y_pad])

        return input_ball_local, cropped_params

    def __get_crop_params__(self, x_center, y_center, w_resize, h_resize, w_original, h_original):
        """Return (x_min, x_max, y_min, y_max) of a w_resize x h_resize window around the centre,
        clipped to the borders of the original image."""
        x_min = max(0, x_center - int(w_resize / 2))
        y_min = max(0, y_center - int(h_resize / 2))

        x_max = min(w_original, x_min + w_resize)
        y_max = min(h_original, y_min + h_resize)

        return x_min, x_max, y_min, y_max

In [None]:
# def check_ttnet():
#     tasks = ['global', 'local', 'event']
#     ttnet = TTNet(dropout_p=0.5, tasks=tasks, input_size=(320, 128), thresh_ball_pos_mask=0.01,
#                   num_frames_sequence=9).cuda()
#     resize_batch_input = torch.rand((1, 27, 128, 320)).cuda()
#     org_ball_pos_xy = torch.rand((1, 2)).cuda()
#     start = time.time()
    
#     pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy = ttnet(resize_batch_input, org_ball_pos_xy)
# #     print("DEBUG Unbalaced loss: ", pred_ball_global.shape, pred_ball_local.shape, pred_events.shape, local_ball_pos_xy.shape)    
        
#     if pred_ball_global is not None:
#         print('pred_ball_global: {}'.format(pred_ball_global.size()))
#     if pred_ball_local is not None:
#         print('pred_ball_local: {}'.format(pred_ball_local.size()))
#     if pred_events is not None:
#         print('pred_events: {}'.format(pred_events.size()))
#     print('local_ball_pos_xy: {}'.format(local_ball_pos_xy.size()))

In [None]:
# check_ttnet()

#### Loss Functions

In [None]:
def create_target_ball(ball_position_xy, sigma, w, h, thresh_mask, device):
    """Build the 1D target vector for a ball detection stage.

    The result has ``w + h`` entries: the first ``w`` hold a Gaussian centred on the ball's
    x coordinate, the last ``h`` a Gaussian centred on its y coordinate. It stays all zeros
    unless 0 < x < w and 0 < y < h.

    :param ball_position_xy: Position of the ball (x,y)
    :param sigma: standard deviation (a hyperparameter)
    :param w: width of the resize image
    :param h: height of the resize image
    :param thresh_mask: values of the 1D Gaussians below thresh_mask are set to 0
    :param device: device the target is created on
    :return: tensor of shape (w + h,)
    """
    width, height = int(w), int(h)
    target = torch.zeros((width + height,), device=device)

    ball_is_visible = (width > ball_position_xy[0] > 0) and (height > ball_position_xy[1] > 0)
    if not ball_is_visible:
        return target

    # (offset into target, number of bins, index into ball_position_xy): x first, then y
    for start, length, coord_idx in ((0, width, 0), (width, height, 1)):
        bins = torch.arange(0, length, device=device)
        target[start:start + length] = gaussian_1d(bins, ball_position_xy[coord_idx], sigma=sigma)

    # Zero the tails of both Gaussians
    target[target < thresh_mask] = 0.
    return target
def gaussian_1d(pos, muy, sigma):
    """Evaluate an unnormalised 1D Gaussian (peak value 1) centred at ``muy`` with std ``sigma``."""
    standardised = (pos - muy) / sigma
    return torch.exp(- (standardised ** 2) / 2)

In [None]:
class Ball_Detection_Loss(nn.Module):
    """Sum of two binary cross-entropies, one over the x bins and one over the y bins.

    :param w: number of leading entries that belong to the x axis
    :param h: number of entries that belong to the y axis (stored, not read in forward)
    :param epsilon: small constant added inside the logarithms
    """

    def __init__(self, w, h, epsilon=1e-9):
        super().__init__()
        self.w = w
        self.h = h
        self.epsilon = epsilon

    def forward(self, pred_ball_position, target_ball_position):
        eps = self.epsilon

        def mean_bce(pred, target):
            # Cross-entropy averaged over every element of the slice
            log_likelihood = target * torch.log(pred + eps) + (1 - target) * torch.log(1 - pred + eps)
            return - torch.mean(log_likelihood)

        x_bins = slice(None, self.w)
        y_bins = slice(self.w, None)
        loss_ball_x = mean_bce(pred_ball_position[:, x_bins], target_ball_position[:, x_bins])
        loss_ball_y = mean_bce(pred_ball_position[:, y_bins], target_ball_position[:, y_bins])
        return loss_ball_x + loss_ball_y
class Events_Spotting_Loss(nn.Module):
    """Weighted binary cross-entropy over the two event outputs.

    :param weights: relative weight of each event; normalised to sum to 1
    :param num_events: stored only; the weight layout is fixed to two events
    :param epsilon: small constant added inside the logarithms
    """

    def __init__(self, weights=(1, 3), num_events=2, epsilon=1e-9):
        super(Events_Spotting_Loss, self).__init__()
        # Build the weights as floats explicitly so the normalisation below is never
        # carried out on an integer tensor.
        weights = torch.tensor(weights, dtype=torch.float32).view(1, 2)
        self.weights = weights / weights.sum()
        self.num_events = num_events
        self.epsilon = epsilon

    def forward(self, pred_events, target_events):
        """
        :param pred_events: (batch_size, 2) predicted probabilities
        :param target_events: (batch_size, 2) target values in [0, 1]
        :return: scalar loss tensor
        """
        # Follow the predictions' device instead of forcing CUDA, so the loss also works on CPU.
        if self.weights.device != pred_events.device:
            self.weights = self.weights.to(pred_events.device)
        return - torch.mean(self.weights * (target_events * torch.log(pred_events + self.epsilon) + (1. - target_events) * torch.log(1 - pred_events + self.epsilon)))

In [None]:
class Unbalance_Loss_Model(nn.Module):
    """Wrap a model and combine its task losses with fixed, normalised weights.

    Args:
        model: network called as ``model(resize_batch_input, org_ball_pos_xy)`` and returning
            ``(pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy)``.
        tasks_loss_weight: one relative weight per task; normalised to sum to 1.
            Weights are consumed in the order global ball, local ball, events,
            skipping tasks whose prediction is ``None``.
        weights_events: per-event weights forwarded to ``Events_Spotting_Loss``.
        input_size: ``(w, h)`` used to build the ball-position targets.
        sigma: standard deviation of the Gaussian ball-position targets.
        thresh_ball_pos_mask: threshold forwarded to ``create_target_ball``.
        device: device on which targets are created.
    """

    def __init__(self, model, tasks_loss_weight, weights_events, input_size, sigma, thresh_ball_pos_mask, device):
        super().__init__()
        self.model = model
        raw_task_weights = torch.tensor(tasks_loss_weight)
        self.tasks_loss_weight = raw_task_weights / raw_task_weights.sum()
        # NOTE(review): num_events is derived from the number of task weights - confirm intended.
        self.num_events = len(tasks_loss_weight)
        self.w, self.h = input_size[0], input_size[1]
        self.sigma = sigma
        self.thresh_ball_pos_mask = thresh_ball_pos_mask
        self.device = device
        self.ball_loss_criterion = Ball_Detection_Loss(self.w, self.h)
        self.event_loss_criterion = Events_Spotting_Loss(weights=weights_events, num_events=self.num_events)

    def _build_ball_targets(self, ball_pos_xy, like, batch_size):
        """Fill a tensor shaped like ``like`` with one ``create_target_ball`` row per sample."""
        targets = torch.zeros_like(like)
        for row in range(batch_size):
            targets[row] = create_target_ball(ball_pos_xy[row], sigma=self.sigma,
                                              w=self.w, h=self.h,
                                              thresh_mask=self.thresh_ball_pos_mask,
                                              device=self.device)
        return targets

    def forward(self, resize_batch_input, org_ball_pos_xy, global_ball_pos_xy, target_events):
        """Run the wrapped model and return its predictions plus the weighted total loss.

        Returns:
            ``(pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy, total_loss, None)``.
        """
        pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy = self.model(resize_batch_input, org_ball_pos_xy)
        batch_size = pred_ball_global.size(0)
        task_idx = 0

        # Global ball position is always predicted.
        target_ball_global = self._build_ball_targets(global_ball_pos_xy, pred_ball_global, batch_size)
        global_ball_loss = self.ball_loss_criterion(pred_ball_global, target_ball_global)
        total_loss = global_ball_loss * self.tasks_loss_weight[task_idx]

        # Local ball position (optional stage).
        if pred_ball_local is not None:
            task_idx += 1
            target_ball_local = self._build_ball_targets(local_ball_pos_xy, pred_ball_local, batch_size)
            local_ball_loss = self.ball_loss_criterion(pred_ball_local, target_ball_local)
            total_loss = total_loss + local_ball_loss * self.tasks_loss_weight[task_idx]

        # Event spotting (optional stage).
        if pred_events is not None:
            task_idx += 1
            target_events = target_events.to(device=self.device)
            event_loss = self.event_loss_criterion(pred_events, target_events)
            total_loss = total_loss + event_loss * self.tasks_loss_weight[task_idx]

        return pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy, total_loss, None

    def run_demo(self, resize_batch_input):
        """Delegate to the wrapped model's ``run_demo`` and return its three predictions."""
        ball_global, ball_local, events = self.model.run_demo(resize_batch_input)
        return ball_global, ball_local, events
class Multi_Task_Learning_Model(nn.Module):
    """Wrap a model and combine its task losses with learnable uncertainty weights.

    Original paper: "Multi-task learning using uncertainty to weigh losses for scene geometry and semantics" - CVPR 2018
    url: https://arxiv.org/pdf/1705.07115.pdf
    refer code: https://github.com/Hui-Li/multi-task-learning-example-PyTorch

    One learnable scalar per task is kept in ``self.log_vars`` (initialised to zero).
    Entries are consumed in the order global ball, local ball, events, skipping
    tasks whose prediction is ``None``. The segmentation task is disabled in this
    notebook version.
    """

    def __init__(self, model, tasks, num_events, weights_events, input_size, sigma, thresh_ball_pos_mask, device):
        super().__init__()
        self.model = model
        self.tasks = tasks
        self.num_tasks = len(tasks)
        self.log_vars = nn.Parameter(torch.zeros((self.num_tasks)))
        self.w, self.h = input_size[0], input_size[1]
        self.sigma = sigma
        self.thresh_ball_pos_mask = thresh_ball_pos_mask
        self.device = device
        self.ball_loss_criterion = Ball_Detection_Loss(self.w, self.h)
        self.event_loss_criterion = Events_Spotting_Loss(weights=weights_events, num_events=num_events)
        print("Learnable Multitask model")

    def _build_ball_targets(self, ball_pos_xy, like, batch_size):
        """Fill a tensor shaped like ``like`` with one ``create_target_ball`` row per sample."""
        targets = torch.zeros_like(like)
        for row in range(batch_size):
            targets[row] = create_target_ball(ball_pos_xy[row], sigma=self.sigma,
                                              w=self.w, h=self.h,
                                              thresh_mask=self.thresh_ball_pos_mask,
                                              device=self.device)
        return targets

    def forward(self, resize_batch_input, org_ball_pos_xy, global_ball_pos_xy, target_events):
        """Run the wrapped model and return its predictions plus the uncertainty-weighted loss.

        Returns:
            ``(pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy, total_loss, log_vars)``
            where ``log_vars`` is the current value of ``self.log_vars`` as a Python list.
        """
        pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy = self.model(resize_batch_input, org_ball_pos_xy)
        batch_size = pred_ball_global.size(0)
        log_vars_idx = 0

        # Global ball position is always predicted.
        target_ball_global = self._build_ball_targets(global_ball_pos_xy, pred_ball_global, batch_size)
        global_ball_loss = self.ball_loss_criterion(pred_ball_global, target_ball_global)
        log_var = self.log_vars[log_vars_idx]
        total_loss = global_ball_loss / (torch.exp(2 * log_var)) + log_var

        # Local ball position (optional stage).
        if pred_ball_local is not None:
            log_vars_idx += 1
            target_ball_local = self._build_ball_targets(local_ball_pos_xy, pred_ball_local, batch_size)
            local_ball_loss = self.ball_loss_criterion(pred_ball_local, target_ball_local)
            log_var = self.log_vars[log_vars_idx]
            total_loss = total_loss + (local_ball_loss / (torch.exp(2 * log_var)) + log_var)

        # Event spotting (optional stage).
        if pred_events is not None:
            log_vars_idx += 1
            target_events = target_events.to(device=self.device)
            event_loss = self.event_loss_criterion(pred_events, target_events)
            log_var = self.log_vars[log_vars_idx]
            # NOTE(review): this term divides by 2 * exp(s) while the ball terms divide
            # by exp(2 * s); confirm the asymmetry is intended.
            total_loss = total_loss + (event_loss / (2 * torch.exp(log_var)) + log_var)

        return pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy, total_loss, self.log_vars.data.tolist()

    def run_demo(self, resize_batch_input):
        """Delegate to the wrapped model's ``run_demo`` and return its three predictions."""
        ball_global, ball_local, events = self.model.run_demo(resize_batch_input)
        return ball_global, ball_local, events

##### Model

In [None]:
# configs.thresh_ball_pos_mask

In [None]:
## Model definition
def create_model(configs):
    """Build the backbone named by ``configs.arch`` and wrap it in a loss-computing module.

    Only ``'ttnet'`` is supported; any other value fails an assertion with
    ``'Undefined model backbone'``. The backbone is wrapped in
    ``Multi_Task_Learning_Model`` when ``configs.multitask_learning`` equals ``True``,
    otherwise in ``Unbalance_Loss_Model``.
    """
    if configs.arch != 'ttnet':
        assert False, 'Undefined model backbone'

    backbone = TTNet(dropout_p=configs.dropout_p, tasks=configs.tasks, input_size=configs.input_size,
                     thresh_ball_pos_mask=configs.thresh_ball_pos_mask,
                     num_frames_sequence=configs.num_frames_sequence)

    # Kept as an equality test so only a value equal to True selects this wrapper.
    if configs.multitask_learning == True:
        return Multi_Task_Learning_Model(backbone, tasks=configs.tasks, num_events=configs.num_events,
                                         weights_events=configs.events_weights_loss,
                                         input_size=configs.input_size, sigma=configs.sigma,
                                         thresh_ball_pos_mask=configs.thresh_ball_pos_mask,
                                         device=configs.device)

    return Unbalance_Loss_Model(backbone, tasks_loss_weight=configs.tasks_loss_weight,
                                weights_events=configs.events_weights_loss, input_size=configs.input_size,
                                sigma=configs.sigma, thresh_ball_pos_mask=configs.thresh_ball_pos_mask,
                                device=configs.device)

def get_num_parameters(model):
    """Return the number of trainable scalars (``requires_grad`` parameters) in ``model``.

    If ``model`` exposes a ``module`` attribute (as parallel wrappers do), the
    parameters of ``model.module`` are counted instead.
    """
    target = getattr(model, 'module', model)
    trainable_sizes = [param.numel() for param in target.parameters() if param.requires_grad]
    return sum(trainable_sizes)

def freeze_model(model, freeze_modules_list):
    """Set ``requires_grad`` on every parameter according to the freeze list.

    A parameter is frozen (``requires_grad = False``) when any entry of
    ``freeze_modules_list`` is a substring of its name; all other parameters are
    made trainable, including ones that were frozen before the call.

    Returns:
        The same ``model`` object, modified in place.
    """
    for param_name, param in model.named_parameters():
        frozen = any(tag in param_name for tag in freeze_modules_list)
        param.requires_grad = not frozen

    return model

def load_weights_local_stage(pretrained_dict):
    """Return a copy of ``pretrained_dict`` whose local-stage entries mirror the global stage.

    For every key containing ``'ball_global_stage'``, the second dot-separated
    component of the key is replaced by ``'ball_local_stage'`` and the value is
    stored under that new key, overriding any existing local-stage entry.
    The input dict itself is not modified.
    """
    merged = dict(pretrained_dict)
    local_overrides = {}
    for layer_name, weights in pretrained_dict.items():
        if 'ball_global_stage' not in layer_name:
            continue
        name_parts = layer_name.split('.')
        name_parts[1] = 'ball_local_stage'
        local_overrides['.'.join(name_parts)] = weights

    merged.update(local_overrides)
    return merged

def load_pretrained_model(model, pretrained_path, gpu_idx, overwrite_global_2_local):
    """Load matching weights from a checkpoint file into ``model``.

    Args:
        model: module to update; if it has a ``module`` attribute the weights are
            loaded into ``model.module`` instead.
        pretrained_path: checkpoint file containing a ``'state_dict'`` entry.
        gpu_idx: ``None`` maps the checkpoint to CPU, otherwise to ``cuda:<gpu_idx>``.
        overwrite_global_2_local: when truthy, copy the global-stage weights over the
            local-stage keys via ``load_weights_local_stage`` before loading.

    Returns:
        The same ``model`` object.

    Raises:
        AssertionError: if ``pretrained_path`` is not an existing file.
    """
    assert os.path.isfile(pretrained_path), "=> no checkpoint found at '{}'".format(pretrained_path)
    map_location = 'cpu' if gpu_idx is None else 'cuda:{}'.format(gpu_idx)
    checkpoint = torch.load(pretrained_path, map_location=map_location)
    pretrained_dict = checkpoint['state_dict']

    # Parallel wrappers keep the real network under `.module`.
    target = model.module if hasattr(model, 'module') else model
    model_state_dict = target.state_dict()

    # Keep only checkpoint entries that the target model actually has.
    pretrained_dict = {key: value for key, value in pretrained_dict.items() if key in model_state_dict}
    if overwrite_global_2_local:
        pretrained_dict = load_weights_local_stage(pretrained_dict)

    # Overlay the checkpoint entries on the current state and load the result.
    model_state_dict.update(pretrained_dict)
    target.load_state_dict(model_state_dict)
    return model

def resume_model(resume_path, arch, gpu_idx):
    """Load a training checkpoint and verify that it matches the requested architecture.

    Args:
        resume_path: checkpoint file to load.
        arch: expected architecture name; compared with ``checkpoint['configs'].arch``.
        gpu_idx: ``None`` maps the checkpoint to CPU, otherwise to ``cuda:<gpu_idx>``.

    Returns:
        The loaded checkpoint object.

    Raises:
        AssertionError: if the file does not exist or the architecture differs.
    """
    assert os.path.isfile(resume_path), "=> no checkpoint found at '{}'".format(resume_path)
    map_location = 'cpu' if gpu_idx is None else 'cuda:{}'.format(gpu_idx)
    checkpoint = torch.load(resume_path, map_location=map_location)
    assert arch == checkpoint['configs'].arch, "Load the different arch..."
    print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch']))

    return checkpoint

def make_data_parallel(model, configs):
    """Move ``model`` to GPU and wrap it according to the parallelism settings in ``configs``.

    * not distributed, ``gpu_idx`` set: plain model on that GPU.
    * not distributed, ``gpu_idx`` is ``None``: ``DataParallel`` over the available GPUs.
    * distributed, ``gpu_idx`` set: ``DistributedDataParallel`` pinned to that GPU;
      ``configs.batch_size`` and ``configs.num_workers`` are divided by
      ``configs.ngpus_per_node`` (in place on ``configs``).
    * distributed, ``gpu_idx`` is ``None``: ``DistributedDataParallel`` without explicit device ids.

    Returns:
        The moved / wrapped model.
    """
    single_gpu = configs.gpu_idx is not None

    if not configs.distributed:
        if single_gpu:
            torch.cuda.set_device(configs.gpu_idx)
            return model.cuda(configs.gpu_idx)
        # DataParallel will divide and allocate batch_size to all available GPUs
        return torch.nn.DataParallel(model).cuda()

    if not single_gpu:
        model.cuda()
        # Without device_ids, DistributedDataParallel divides and allocates
        # batch_size to all available GPUs itself.
        return torch.nn.parallel.DistributedDataParallel(model)

    # One GPU per process: set the single device scope explicitly, otherwise
    # DistributedDataParallel would use all available devices.
    torch.cuda.set_device(configs.gpu_idx)
    model.cuda(configs.gpu_idx)
    # With one GPU per process the batch size (and worker count) must be split
    # manually across the GPUs of the node.
    configs.batch_size = int(configs.batch_size / configs.ngpus_per_node)
    configs.num_workers = int((configs.num_workers + configs.ngpus_per_node - 1) / configs.ngpus_per_node)
    return torch.nn.parallel.DistributedDataParallel(model, device_ids=[configs.gpu_idx],
                                                     find_unused_parameters=True)

In [None]:
def check_model(configs):
    """Build and return the model for ``configs`` (thin wrapper around ``create_model``)."""
    return create_model(configs)

In [None]:
def make_data_parallel(model, configs):
    """Move ``model`` to GPU for single-process training.

    This variant does not read ``configs.distributed``; it only distinguishes
    between a specific GPU and all available GPUs.

    Returns:
        The model on ``configs.gpu_idx`` when that is set, otherwise the model
        wrapped in ``torch.nn.DataParallel`` and moved to CUDA.
    """
    if configs.gpu_idx is None:
        # DataParallel will divide and allocate batch_size to all available GPUs
        return torch.nn.DataParallel(model).cuda()

    torch.cuda.set_device(configs.gpu_idx)
    return model.cuda(configs.gpu_idx)
def freeze_model(model, freeze_modules_list):
    """Freeze the parameters whose name contains any entry of ``freeze_modules_list``.

    Every other parameter is (re-)enabled for training. The model is modified in
    place and returned.
    """
    for name, parameter in model.named_parameters():
        matches_freeze_list = any(pattern in name for pattern in freeze_modules_list)
        parameter.requires_grad = not matches_freeze_list

    return model

In [None]:
def evaluate_one_epoch(val_loader, model, epoch, configs, logger=None):
    """Run one pass over ``val_loader`` without gradients and return the average loss.

    Args:
        val_loader: iterable yielding
            ``(resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events)`` batches.
        model: loss-computing wrapper whose forward returns the total loss as its fifth output.
        epoch: epoch number, used only in the progress prefix.
        configs: reads ``num_epochs``, ``device``, ``distributed``, ``gpu_idx``,
            ``world_size`` (distributed only) and ``print_freq``.
        logger: optional logger; when given, progress is logged every ``print_freq`` batches.

    Returns:
        ``losses.avg`` - the loss meter's average, updated with each batch's size as weight.
    """
    # AverageMeter / ProgressMeter are not defined in this cell - presumably from utils.misc.
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')

    progress = ProgressMeter(len(val_loader), [batch_time, data_time, losses],
                             prefix="Evaluate - Epoch: [{}/{}]".format(epoch, configs.num_epochs))
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        for batch_idx, (resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events) in enumerate(
                tqdm(val_loader)):
            # Time spent waiting for the loader to produce this batch.
            data_time.update(time.time() - start_time)
            batch_size = resized_imgs.size(0)
#             target_seg = target_seg.to(configs.device, non_blocking=True)
            resized_imgs = resized_imgs.to(configs.device, non_blocking=True).float()
            pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy, total_loss, _ = model(
                resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events)

            # For torch.nn.DataParallel case
            if (not configs.distributed) and (configs.gpu_idx is None):
                total_loss = torch.mean(total_loss)

            if configs.distributed:
                reduced_loss = reduce_tensor(total_loss.data, configs.world_size)
            else:
                reduced_loss = total_loss.data
            losses.update(to_python_float(reduced_loss), batch_size)
            # measure elapsed time
            # NOTE(review): unconditional CUDA sync - this call needs a CUDA-enabled runtime.
            torch.cuda.synchronize()
            batch_time.update(time.time() - start_time)

            # Log message
            if logger is not None:
                if ((batch_idx + 1) % configs.print_freq) == 0:
                    logger.info(progress.get_message(batch_idx))

            # Restart the clock for the next batch's data/batch timing.
            start_time = time.time()

    return losses.avg

##### Training Debugger

In [None]:
# Build the loss-computing model, move it to the GPU(s), then apply the freeze list.
model = create_model(configs)
model = make_data_parallel(model, configs)
model = freeze_model(model, configs.freeze_modules_list)

In [None]:
# Trainable parameter count (requires_grad only), reported in millions.
print("Total Parameters: ", get_num_parameters(model) / 1000000 , "M")

In [None]:
# Optimiser / LR scheduler plus the bookkeeping used by the training loop.
optimizer       = create_optimizer(configs, model)
lr_scheduler    = create_lr_scheduler(optimizer, configs)
best_val_loss   = np.inf   # any finite validation loss will be an improvement
earlystop_count = 0        # consecutive epochs without a new best
is_best         = False    # only updated when validation runs

In [None]:
# Show which checkpoint paths (if any) are currently configured.
print(configs.pretrained_path, configs.resume_path)

In [None]:
# NOTE(review): hard-coded checkpoint path; resume_model asserts that this file exists.
configs.resume_path = "../../checkpoints/ttnet/ttnet_1st_phase_epoch_6.pth"

In [None]:
# Optionally initialise the model from pretrained weights.
if configs.pretrained_path is not None:
    # The GPU index lives on `configs`; a bare `gpu_idx` name is not defined in this notebook.
    model = load_pretrained_model(model, configs.pretrained_path, configs.gpu_idx,
                                  configs.overwrite_global_2_local)
    if logger is not None:
        logger.info('loaded pretrained model at {}'.format(configs.pretrained_path))

In [None]:
# Optionally resume from a checkpoint: restores model weights, optimiser and
# LR-scheduler state, the best validation loss, the early-stop counter, and sets
# the next epoch to run.
if configs.resume_path is not None:
    checkpoint = resume_model(configs.resume_path, configs.arch, configs.gpu_idx)
    # Parallel wrappers keep the real network under `.module`.
    if hasattr(model, 'module'):
        model.module.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    best_val_loss = checkpoint['best_val_loss']
    earlystop_count = checkpoint['earlystop_count']
    # Continue with the epoch after the one stored in the checkpoint.
    configs.start_epoch = checkpoint['epoch'] + 1

In [None]:
# Build the train / validation / test loaders used by the training loop.
if logger is not None:
    logger.info(">>> Loading dataset & getting dataloader...")
train_loader, val_loader, train_sampler = create_train_val_dataloader(configs)
test_loader = create_test_dataloader(configs)

In [None]:
# Log the number of batches per split; val_loader may be None, so it is checked first.
if logger is not None:
    logger.info('number of batches in train set: {}'.format(len(train_loader)))
    if val_loader is not None:
        logger.info('number of batches in val set: {}'.format(len(val_loader)))
    logger.info('number of batches in test set: {}'.format(len(test_loader)))
# Evaluation-only mode: run a single validation pass, labelled with the epoch before start_epoch.
if configs.evaluate:
    assert val_loader is not None, "The validation should not be None"
    val_loss = evaluate_one_epoch(val_loader, model, configs.start_epoch - 1, configs, logger)
    print('Evaluate, val_loss: {}'.format(val_loss))
    

##### Training

In [None]:
def train_one_epoch(train_loader, model, optimizer, epoch, configs, logger):
    """Train ``model`` for one pass over ``train_loader`` and return the average loss.

    Args:
        train_loader: iterable yielding
            ``(resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events)`` batches.
        model: loss-computing wrapper whose forward returns the total loss as its fifth output.
        optimizer: optimiser stepped once per batch.
        epoch: epoch number, used only in the progress prefix.
        configs: reads ``num_epochs``, ``device``, ``distributed``, ``gpu_idx``,
            ``world_size`` (distributed only) and ``print_freq``.
        logger: logger or ``None``; when given, progress is logged every ``print_freq`` batches.

    Returns:
        ``losses.avg`` - the loss meter's average, updated with each batch's size as weight.
    """
    # AverageMeter / ProgressMeter are not defined in this cell - presumably from utils.misc.
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')

    progress = ProgressMeter(len(train_loader), [batch_time, data_time, losses],
                             prefix="Train - Epoch: [{}/{}]".format(epoch, configs.num_epochs))

    # switch to train mode
    model.train()
    start_time = time.time()
    for batch_idx, (resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events) in enumerate(
            tqdm(train_loader)):
        # Time spent waiting for the loader to produce this batch.
        data_time.update(time.time() - start_time)
        batch_size = resized_imgs.size(0)
        resized_imgs = resized_imgs.to(configs.device, non_blocking=True).float()
        pred_ball_global, pred_ball_local, pred_events, local_ball_pos_xy, total_loss, _ = model(resized_imgs, org_ball_pos_xy, global_ball_pos_xy, target_events)
        # For torch.nn.DataParallel case
        if (not configs.distributed) and (configs.gpu_idx is None):
            total_loss = torch.mean(total_loss)

        # zero the parameter gradients
        optimizer.zero_grad()
        # compute gradient and perform backpropagation
        total_loss.backward()
        optimizer.step()

        if configs.distributed:
            reduced_loss = reduce_tensor(total_loss.data, configs.world_size)
        else:
            reduced_loss = total_loss.data
        losses.update(to_python_float(reduced_loss), batch_size)
        # measure elapsed time
        # NOTE(review): unconditional CUDA sync - this call needs a CUDA-enabled runtime.
        torch.cuda.synchronize()
        batch_time.update(time.time() - start_time)

        # Log message
        if logger is not None:
            if ((batch_idx + 1) % configs.print_freq) == 0:
                logger.info(progress.get_message(batch_idx))

        # Restart the clock for the next batch's data/batch timing.
        start_time = time.time()

    return losses.avg

In [None]:
# Main training loop: train, optionally validate/test, checkpoint, early-stop, step the LR.
for epoch in range(configs.start_epoch, configs.num_epochs + 1):
    # Get the current learning rate (the value of the last param group is reported)
    for param_group in optimizer.param_groups:
        lr = param_group['lr']
    if logger is not None:
        logger.info('{}'.format('*-' * 40))
        logger.info('{} {}/{} {}'.format('=' * 35, epoch, configs.num_epochs, '=' * 35))
        logger.info('{}'.format('*-' * 40))
        logger.info('>>> Epoch: [{}/{}] learning rate: {:.2e}'.format(epoch, configs.num_epochs, lr))

    if configs.distributed:
        train_sampler.set_epoch(epoch)
    # train for one epoch
    train_loss = train_one_epoch(train_loader, model, optimizer, epoch, configs, logger)
    loss_dict = {'train': train_loss}
    if not configs.no_val:
        val_loss = evaluate_one_epoch(val_loader, model, epoch, configs, logger)
        is_best = val_loss <= best_val_loss
        best_val_loss = min(val_loss, best_val_loss)
        loss_dict['val'] = val_loss

    if not configs.no_test:
        test_loss = evaluate_one_epoch(test_loader, model, epoch, configs, logger)
        loss_dict['test'] = test_loss

    # Save checkpoint on a new best validation loss or every `checkpoint_freq` epochs
    if (is_best or ((epoch % configs.checkpoint_freq) == 0)):
        saved_state = get_saved_state(model, optimizer, lr_scheduler, epoch, configs, best_val_loss,
                                      earlystop_count)
        save_checkpoint(configs.checkpoints_dir, configs.saved_fn, saved_state, is_best, epoch)
    # Check early stop training
    # NOTE(review): `is_best` is only refreshed when validation runs, so with
    # `no_val` the counter increases every epoch - confirm this is intended.
    if configs.earlystop_patience is not None:
        earlystop_count = 0 if is_best else (earlystop_count + 1)
        print_string = ' |||\t earlystop_count: {}'.format(earlystop_count)
        if configs.earlystop_patience <= earlystop_count:
            print_string += '\n\t--- Early stopping!!!'
            # Log before leaving the loop; otherwise the stop message is never emitted.
            if logger is not None:
                logger.info(print_string)
            break
        print_string += '\n\t--- Continue training..., earlystop_count: {}'.format(earlystop_count)
        if logger is not None:
            logger.info(print_string)
    # Adjust learning rate
    if configs.lr_type == 'plateau':
        assert (not configs.no_val), "Only use plateau when having validation set"
        lr_scheduler.step(val_loss)
    else:
        lr_scheduler.step()

### Phase 2

In [None]:
# Phase 2 configuration: start from the phase-1 checkpoint, copy the global-stage
# weights to the local stage, and weight only the event and local-ball losses.
configs = parse_configs()
configs.distributed = False
# configs.multitask_learning = True

# Phase 2
configs.saved_fn = 'ttnet_2nd_phase'
# Attribute names must be valid identifiers: the CLI flag spelling `no-val` is the
# attribute `no_val` (same for `smooth_labelling` below), as in the Phase 1 cell.
configs.no_val = True
configs.lr = 0.001
configs.lr_type = 'step_lr'
configs.lr_step_size = 10
configs.lr_factor = 0.1
configs.gpu_idx = 0
configs.global_weight = 0.
configs.event_weight = 2.
configs.local_weight = 1.
configs.pretrained_path = "../checkpoints/ttnet_1st_phase/ttnet_1st_phase_epoch_30.pth"
configs.overwrite_global_2_local = True
configs.freeze_global = True
configs.smooth_labelling = True
configs.sigma = 1.0
configs.print_freq = 1000