In [1]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
from collections import defaultdict
from utils import *
from model_def import *

reference: https://github.com/pytorch/examples/blob/master/imagenet/main.py

In [2]:
cudnn.benchmark = True
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args = defaultdict(lambda:False, {'print_freq': 250, 'lr':.03,'arch':'resnet18'} )

In [4]:
data_loader_sup_train, data_loader_sup_val, data_loader_unsup = image_loader('data/ssl_data_96',32)

In [5]:
def train(sup_loader, unsup_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(sup_loader), batch_time, data_time, losses, top1,
                             top5, prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, ((input_sup, target_sup), (input_unsup,target_unsup)) in enumerate(zip(sup_loader,unsup_loader)):
        # measure data loading time
        data_time.update(time.time() - end)

        input_sup = input_sup.to(device)
        target_sup = target_sup.to(device)
        input_unsup = input_sup.to(device)
        target_unsup = target_sup.to(device)
        

        # compute output
        output_sup = model(input_sup)
        #output_unsup = model(input_unsup)
        loss = criterion(output_sup, target_sup)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output_sup, target_sup, topk=(1, 5))
        losses.update(loss.item(), input_sup.size(0))
        top1.update(acc1[0], input_sup.size(0))
        top5.update(acc5[0], input_sup.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args['print_freq'] == 0:
            progress.print(i)


In [6]:
def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input_val, target_val) in enumerate(val_loader):
            input_val = input_val.to(device)
            target_val = target_val.to(device)

            # compute output
            output_val = model(input_val)
            loss = criterion(output_val, target_val)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output_val, target_val, topk=(1, 5))
            losses.update(loss.item(), input_val.size(0))
            top1.update(acc1[0], input_val.size(0))
            top5.update(acc5[0], input_val.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args['print_freq'] == 0:
                progress.print(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg


In [7]:
global best_acc1
# create model
model = resnet34() if args['arch']=='resnet32' else resnet18()
model = model.to(device)
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
# torch.optim.SGD(model.parameters(), args['lr'],
#                             momentum=args.momentum,
#                             weight_decay=args.weight_decay)

# optionally resume from a checkpoint
if args['resume']:
    if os.path.isfile(args['resume']):
        print("=> loading checkpoint '{}'".format(args['resume']))
        checkpoint = torch.load(args['resume'])
        args['start_epoch'] = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args['resume']))
else:
    best_acc1 = -1

for epoch in range(0, 3):
    #adjust_learning_rate(optimizer, epoch, args)
    # train for one epoch
    train(data_loader_sup_train, data_loader_unsup, model, criterion, optimizer, epoch, args)

    # evaluate on validation set
    acc1 = validate(data_loader_sup_val, model, criterion, args)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)

    save_checkpoint({
        'epoch': epoch + 1,
        'arch': args.arch,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer' : optimizer.state_dict(),
    }, is_best)




Epoch: [0][   0/2000]	Time  1.715 ( 1.715)	Data  0.109 ( 0.109)	Loss 8.6142e+00 (8.6142e+00)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)
Epoch: [0][ 250/2000]	Time  0.103 ( 0.114)	Data  0.002 ( 0.003)	Loss 7.0099e+00 (8.6432e+00)	Acc@1   0.00 (  0.02)	Acc@5   0.00 (  0.35)
Epoch: [0][ 500/2000]	Time  0.111 ( 0.111)	Data  0.003 ( 0.003)	Loss 6.9405e+00 (7.8106e+00)	Acc@1   0.00 (  0.06)	Acc@5   0.00 (  0.40)
Epoch: [0][ 750/2000]	Time  0.112 ( 0.111)	Data  0.003 ( 0.002)	Loss 6.9734e+00 (7.5282e+00)	Acc@1   0.00 (  0.08)	Acc@5   0.00 (  0.40)
Epoch: [0][1000/2000]	Time  0.100 ( 0.110)	Data  0.003 ( 0.002)	Loss 6.9293e+00 (7.3853e+00)	Acc@1   0.00 (  0.07)	Acc@5   0.00 (  0.39)
Epoch: [0][1250/2000]	Time  0.114 ( 0.110)	Data  0.003 ( 0.002)	Loss 6.9059e+00 (7.3005e+00)	Acc@1   0.00 (  0.07)	Acc@5   0.00 (  0.39)
Epoch: [0][1500/2000]	Time  0.110 ( 0.110)	Data  0.003 ( 0.003)	Loss 7.0110e+00 (7.2440e+00)	Acc@1   0.00 (  0.08)	Acc@5   0.00 (  0.41)
Epoch: [0][1750/2000]	Time  0.116 ( 0.110

AttributeError: 'collections.defaultdict' object has no attribute 'arch'