In [95]:
import os
import argparse
import time
import csv
import datetime
import random
from path import Path
import numpy as np
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.model_zoo as model_zoo

import torchvision.models as models

import custom_transforms
from utils import tensor2array, save_checkpoint
from datasets.sequence_folders import SequenceFolder

from loss_functions import compute_smooth_loss, compute_photo_and_geometry_loss, compute_errors
from tensorboardX import SummaryWriter

import helper, data_helper

In [47]:
device = ("cuda:0" if torch.cuda.is_available() else "cpu")

In [29]:
class ResnetEncoder(nn.Module):
    """Pytorch module for a resnet encoder
    """
    def __init__(self, num_layers, pretrained, num_input_images=1):
        super(ResnetEncoder, self).__init__()

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])

        resnets = {18: models.resnet18,
                   34: models.resnet34,
                   50: models.resnet50,
                   101: models.resnet101,
                   152: models.resnet152}

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        if num_input_images > 1:
            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
        else:
            self.encoder = resnets[num_layers](pretrained)

        if num_layers > 34:
            self.num_ch_enc[1:] *= 4

    def forward(self, input_image):
        self.features = []
        x = input_image
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        self.features.append(self.encoder.relu(x))
        self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
        self.features.append(self.encoder.layer2(self.features[-1]))
        self.features.append(self.encoder.layer3(self.features[-1]))
        self.features.append(self.encoder.layer4(self.features[-1]))

        return self.features

In [30]:
class ResNetMultiImageInput(models.ResNet):
    """Constructs a resnet model with varying number of input images.
    Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    """
    def __init__(self, block, layers, num_classes=1000, num_input_images=1):
        super(ResNetMultiImageInput, self).__init__(block, layers)
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
    """Constructs a ResNet model.
    Args:
        num_layers (int): Number of resnet layers. Must be 18 or 50
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        num_input_images (int): Number of frames stacked as input
    """
    assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
    blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
    block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
    model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)

    if pretrained:
        loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
        loaded['conv1.weight'] = torch.cat(
            [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
        model.load_state_dict(loaded)
    return model

In [31]:
class PoseDecoder(nn.Module):
    def __init__(self, num_ch_enc, num_input_features=1, num_frames_to_predict_for=1, stride=1):
        super(PoseDecoder, self).__init__()

        self.num_ch_enc = num_ch_enc
        self.num_input_features = num_input_features

        if num_frames_to_predict_for is None:
            num_frames_to_predict_for = num_input_features - 1
        self.num_frames_to_predict_for = num_frames_to_predict_for

        self.convs = OrderedDict()
        self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
        self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
        self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
        self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)

        self.relu = nn.ReLU()

        self.net = nn.ModuleList(list(self.convs.values()))

    def forward(self, input_features):
        last_features = [f[-1] for f in input_features]

        cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
        cat_features = torch.cat(cat_features, 1)

        out = cat_features
        for i in range(3):
            out = self.convs[("pose", i)](out)
            if i != 2:
                out = self.relu(out)

        out = out.mean(3).mean(2)

        pose = 0.01 * out.view(-1, 6)

        return pose

In [32]:
class PoseResNet(nn.Module):

    def __init__(self, num_layers = 18, pretrained = True):
        super(PoseResNet, self).__init__()
        self.encoder = ResnetEncoder(num_layers = num_layers, pretrained = pretrained, num_input_images=2)
        self.decoder = PoseDecoder(self.encoder.num_ch_enc)

    def init_weights(self):
        pass

    def forward(self, img1, img2):
        x = torch.cat([img1,img2],1)
        features = self.encoder(x)
        pose = self.decoder([features])
        return pose

In [63]:
torch.backends.cudnn.benchmark = True

model = PoseResNet().to(device)
model.train()

tgt_img = torch.randn(4, 3, 256, 256).to(device)
ref_imgs = [torch.randn(4, 3, 256, 256).to(device) for i in range(2)]

pose = model(tgt_img, ref_imgs[0])

print(pose.size())

torch.Size([4, 6])


In [64]:
class ConvBlock(nn.Module):
    """Layer to perform a convolution followed by ELU
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.nonlin(out)
        return out

class Conv3x3(nn.Module):
    """Layer to pad and convolve input
    """
    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()

        if use_refl:
            self.pad = nn.ReflectionPad2d(1)
        else:
            self.pad = nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        out = self.pad(x)
        out = self.conv(out)
        return out

def upsample(x):
    """Upsample input tensor by a factor of 2
    """
    return F.interpolate(x, scale_factor=2, mode="nearest")

class DepthDecoder(nn.Module):
    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
        super(DepthDecoder, self).__init__()

        self.alpha = 10
        self.beta = 0.01

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # decoder
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            # upconv_0
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        self.outputs = []

        # decoder
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)
            if i in self.scales:
                self.outputs.append(self.alpha * self.sigmoid(self.convs[("dispconv", i)](x)) + self.beta)

        self.outputs = self.outputs[::-1]
        return self.outputs


class DispResNet(nn.Module):

    def __init__(self, num_layers = 18, pretrained = True):
        super(DispResNet, self).__init__()
        self.encoder = ResnetEncoder(num_layers = num_layers, pretrained = pretrained, num_input_images=1)
        self.decoder = DepthDecoder(self.encoder.num_ch_enc)

    def init_weights(self):
        pass

    def forward(self, x):
        features = self.encoder(x)
        outputs = self.decoder(features)
        
        if self.training:
            return outputs
        else:
            return outputs[0]

In [65]:
model = DispResNet().to(device)
model.train()

B = 12

tgt_img = torch.randn(B, 3, 256, 256).to(device)
ref_imgs = [torch.randn(B, 3, 256, 256).to(device) for i in range(2)]

tgt_depth = model(tgt_img)

print(tgt_depth[0].size())

torch.Size([12, 1, 256, 256])


In [66]:
random_seed = 0
random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)
cudnn.deterministic = True
cudnn.benchmark = True

In [91]:
datadir = '/Users/alexandergao/Documents/DeepLearning/Competition/competition/data'
n_epochs = 200
batch_size = 1

lr = 0.0001
beta = 0.999
momentum = 0.9
best_error = -1
smooth_loss_weight = 0.1
photo_loss_weight = 1
geometry_consistency_weight = 0.5
iters = 0

resnet_layers = 50 #can also be set to 50
n_scales = 1
sequence_length = 3
ssim = True
mask = True
auto_mask = True
pretrain = True
padding_mode = 'zeros'

stats_update_rate = 10

# parser.add_argument('--weight-decay', '--wd', default=0, type=float, metavar='W', help='weight decay')
# parser.add_argument('--print-freq', default=10, type=int, metavar='N', help='print frequency')
# parser.add_argument('--log-summary', default='progress_log_summary.csv', metavar='PATH', help='csv where to save per-epoch train and valid stats')
# parser.add_argument('--log-full', default='progress_log_full.csv', metavar='PATH', help='csv where to save per-gradient descent train stats')
# parser.add_argument('--log-output', action='store_true', help='will log dispnet outputs at validation step')
# parser.add_argument('--num-scales', '--number-of-scales', type=int, help='the number of scales', metavar='W', default=1)
# parser.add_argument('-p', '--photo-loss-weight', type=float, help='weight for photometric loss', metavar='W', default=1)
# parser.add_argument('-s', '--smooth-loss-weight', type=float, help='weight for disparity smoothness loss', metavar='W', default=0.1)
# parser.add_argument('-c', '--geometry-consistency-weight', type=float, help='weight for depth consistency loss', metavar='W', default=0.5)
# parser.add_argument('--with-ssim', type=int, default=1, help='with ssim or not')
# parser.add_argument('--with-mask', type=int, default=1, help='with the the mask for moving objects and occlusions or not')
# parser.add_argument('--with-auto-mask', type=int,  default=0, help='with the the mask for stationary points')
# parser.add_argument('--with-pretrain', type=int,  default=1, help='with or without imagenet pretrain for resnet')
# parser.add_argument('--dataset', type=str, choices=['kitti', 'cs', 'nyu'], default='kitti', help='the dataset to train')
# parser.add_argument('--pretrained-disp', dest='pretrained_disp', default=None, metavar='PATH', help='path to pre-trained dispnet model')
# parser.add_argument('--pretrained-pose', dest='pretrained_pose', default=None, metavar='PATH', help='path to pre-trained Pose net model')
# parser.add_argument('--name', dest='name', type=str, required=True, help='name of the experiment, checkpoints are stored in checpoints/name')
# parser.add_argument('--padding-mode', type=str, choices=['zeros', 'border'], default='zeros',
#                     help='padding mode for image warping : this is important for photometric differenciation when going outside target image.'
#                          ' zeros will null gradients outside target image.'
#                          ' border will only null gradients of the coordinate outside (x or y)')
# parser.add_argument('--with-gt', action='store_true', help='use ground truth for validation. \
#                     You need to store it in npy 2D arrays see data/kitti_raw_loader.py for an example')

In [74]:
def train(args, train_loader, disp_net, pose_net, optimizer, train_writer):
    global photo_loss_weight, smooth_loss_weight, geometry_consistency_weight
    w1, w2, w3 = photo_loss_weight, smooth_loss_weight, geometry_consistency_weight

    # switch to train mode
    disp_net.train()
    pose_net.train()

    losses = []
    
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(trainloader):

        # measure data loading time
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)

        # compute output
        tgt_depth, ref_depths = compute_depth(disp_net, tgt_img, ref_imgs)
        poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs)

        loss_1, loss_3 = compute_photo_and_geometry_loss(tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths,
                                                         poses, poses_inv, n_scales, ssim,
                                                         mask, auto_mask, padding_mode)

        loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs)

        loss = w1*loss_1 + w2*loss_2 + w3*loss_3

        if i % stats_update_rate == 0:
            print('photometric_error', loss_1.item(), n_iter)
            print('disparity_smoothness_loss', loss_2.item(), n_iter)
            print('geometry_consistency_loss', loss_3.item(), n_iter)
            print('total_loss', loss.item(), n_iter)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            print('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))

        iters += 1
        
        losses.append(loss.item())

    return losses.mean()

In [None]:
@torch.no_grad()
def validate_without_gt(args, val_loader, disp_net, pose_net, epoch, output_writers=[]):
    global device
    batch_time = AverageMeter()
    losses = AverageMeter(i=4, precision=4)
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()
    pose_net.eval()

    end = time.time()
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        ref_imgs = [img.to(device) for img in ref_imgs]
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # compute output
        tgt_depth = [1 / disp_net(tgt_img)]
        ref_depths = []
        for ref_img in ref_imgs:
            ref_depth = [1 / disp_net(ref_img)]
            ref_depths.append(ref_depth)

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input', tensor2array(tgt_img[0]), 0)

            output_writers[i].add_image('val Dispnet Output Normalized',
                                        tensor2array(1/tgt_depth[0][0], max_value=None, colormap='magma'),
                                        epoch)
            output_writers[i].add_image('val Depth Output',
                                        tensor2array(tgt_depth[0][0], max_value=10),
                                        epoch)

        poses, poses_inv = compute_pose_with_inv(pose_net, tgt_img, ref_imgs)

        loss_1, loss_3 = compute_photo_and_geometry_loss(tgt_img, ref_imgs, intrinsics, tgt_depth, ref_depths,
                                                         poses, poses_inv, n_scales, ssim,
                                                         mask, False, padding_mode)

        loss_2 = compute_smooth_loss(tgt_depth, tgt_img, ref_depths, ref_imgs)

        loss_1 = loss_1.item()
        loss_2 = loss_2.item()
        loss_3 = loss_3.item()

        loss = loss_1
        losses.update([loss, loss_1, loss_2, loss_3])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('valid: Time {} Loss {}'.format(batch_time, losses))

    return losses.avg, ['Total loss', 'Photo loss', 'Smooth loss', 'Consistency loss']


@torch.no_grad()
def validate_with_gt(args, val_loader, disp_net, epoch, output_writers=[]):
    global device
    batch_time = AverageMeter()
    error_names = ['abs_diff', 'abs_rel', 'sq_rel', 'a1', 'a2', 'a3']
    errors = AverageMeter(i=len(error_names))
    log_outputs = len(output_writers) > 0

    # switch to evaluate mode
    disp_net.eval()

    end = time.time()
    for i, (tgt_img, depth) in enumerate(val_loader):
        tgt_img = tgt_img.to(device)
        depth = depth.to(device)

        # check gt
        if depth.nelement() == 0:
            continue

        # compute output
        output_disp = disp_net(tgt_img)
        output_depth = 1/output_disp[:, 0]

        if log_outputs and i < len(output_writers):
            if epoch == 0:
                output_writers[i].add_image('val Input', tensor2array(tgt_img[0]), 0)
                depth_to_show = depth[0]
                output_writers[i].add_image('val target Depth',
                                            tensor2array(depth_to_show, max_value=10),
                                            epoch)
                depth_to_show[depth_to_show == 0] = 1000
                disp_to_show = (1/depth_to_show).clamp(0, 10)
                output_writers[i].add_image('val target Disparity Normalized',
                                            tensor2array(disp_to_show, max_value=None, colormap='magma'),
                                            epoch)

            output_writers[i].add_image('val Dispnet Output Normalized',
                                        tensor2array(output_disp[0], max_value=None, colormap='magma'),
                                        epoch)
            output_writers[i].add_image('val Depth Output',
                                        tensor2array(output_depth[0], max_value=10),
                                        epoch)

        if depth.nelement() != output_depth.nelement():
            b, h, w = depth.size()
            output_depth = torch.nn.functional.interpolate(output_depth.unsqueeze(1), [h, w]).squeeze(1)

        errors.update(compute_errors(depth, output_depth, args.dataset))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('valid: Time {} Abs Error {:.4f} ({:.4f})'.format(batch_time, errors.val[0], errors.avg[0]))
    return errors.avg, error_names

def compute_depth(disp_net, tgt_img, ref_imgs):
    tgt_depth = [1/disp for disp in disp_net(tgt_img)]

    ref_depths = []
    for ref_img in ref_imgs:
        ref_depth = [1/disp for disp in disp_net(ref_img)]
        ref_depths.append(ref_depth)

    return tgt_depth, ref_depths

def compute_pose_with_inv(pose_net, tgt_img, ref_imgs):
    poses = []
    poses_inv = []
    for ref_img in ref_imgs:
        poses.append(pose_net(tgt_img, ref_img))
        poses_inv.append(pose_net(ref_img, tgt_img))

    return poses, poses_inv

In [94]:
timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
checkpoint_save_path = '/'.join(['training_checkpoints', save_path, timestamp])

print('=> will save everything to {}'.format(checkpoint_save_path))
os.makedirs(checkpoint_save_path, exist_ok=True)
    
output_writers = []

# Data loading code
normalize = custom_transforms.Normalize(mean=[0.45, 0.45, 0.45],
                                        std=[0.225, 0.225, 0.225])

train_transform = custom_transforms.Compose([custom_transforms.RandomHorizontalFlip(),
                                             custom_transforms.RandomScaleCrop(),
                                             custom_transforms.ArrayToTensor(),
                                             normalize])

valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(),
                                             normalize])

print("=> fetching scenes in '{}'".format(datadir))
train_set = SequenceFolder(datadir,
                           transform=train_transform,
                           seed=random_seed,
                           train=True,
                           sequence_length=sequence_length)

# if no Groundtruth is avalaible, Validation set is the same type as training set to measure photometric loss from warping
# if args.with_gt:
#     from datasets.validation_folders import ValidationSet
#     val_set = ValidationSet(
#         args.data,
#         transform=valid_transform
#     )
# else:
val_set = SequenceFolder(datadir,
                         transform=valid_transform,
                         seed=random_seed,
                         train=False,
                         sequence_length=sequence_length)

print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))

trainloader = torch.utils.data.DataLoader(train_set,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=2,
                                          pin_memory=True)

testloader = torch.utils.data.DataLoader(test_set,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=2,
                                         pin_memory=True)

disp_net = DispResNet(resnet_layers, pretrain).to(device)
pose_net = PoseResNet(18, pretrain).to(device)

# load parameters
if resume_training:
    print("=> using pre-trained weights for DispResNet")
    weights = torch.load(pretrained_disp)
    disp_net.load_state_dict(weights['state_dict'], strict=False)

    print("=> using pre-trained weights for PoseResNet")
    weights = torch.load(pretrained_pose)
    pose_net.load_state_dict(weights['state_dict'], strict=False)

disp_net = torch.nn.DataParallel(disp_net)
pose_net = torch.nn.DataParallel(pose_net)

optim_params = [{'params': disp_net.parameters(), 'lr': lr},
                {'params': pose_net.parameters(), 'lr': lr}]

optim = torch.optim.Adam(optim_params, betas=(momentum, beta), weight_decay=weight_decay)

=> will save everything to training_checkpoints/save/05-01-20:47
=> fetching scenes in '/Users/alexandergao/Documents/DeepLearning/Competition/competition/data'


FileNotFoundError: [Errno 2] No such file or directory: Path('/Users/alexandergao/Documents/DeepLearning/Competition/competition/data/train.txt')

 * Avg Loss : 3.003


In [72]:
##################################
########## TRAINING LOOP #########
##################################

for epoch in range(n_epochs):

    # train for one epoch
    train_loss = train(args, train_loader, disp_net, pose_net, optimizer, training_writer)
    print(' * Avg Loss : {:.3f}'.format(train_loss))

    # evaluate on validation set
    if args.with_gt:
        errors, error_names = validate_with_gt(args, val_loader, disp_net, epoch, output_writers)
    else:
        errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_net, epoch, output_writers)
    error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))
    print(' * Avg {}'.format(error_string))

    for error, name in zip(errors, error_names):
        training_writer.add_scalar(name, error, epoch)

    # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
    decisive_error = errors[1]
    if best_error < 0:
        best_error = decisive_error

    # remember lowest error and save checkpoint
    is_best = decisive_error < best_error
    best_error = min(best_error, decisive_error)
    save_checkpoint(
        checkpoint_save_path, {
            'epoch': epoch + 1,
            'disp_net_state_dict': disp_net.module.state_dict(),
            'pose_net_state_dict': pose_net.module.state_dict(),
        },
        is_best)

NameError: name 'args' is not defined