# In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from PIL import Image
# For our model
import torchvision
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import pickle

import os
import shutil
import pandas as pd
import numpy as np
from statistics import mean
import copy

from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.metrics import *

from utils import *

# Select the compute device: prefer the second GPU ("cuda:1"), fall back to the
# first GPU when only one is visible, and to the CPU when CUDA is unavailable.
# Without the device-count check, a single-GPU host would get an invalid
# "cuda:1" device that only fails later, on the first `.to(device)` call.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# In [7]:
def create_base_dir(dataset):
    """Return the path of ``dataset`` joined onto the relative ``../dataset/`` root.

    Despite the name, nothing is created on disk; this only builds the path
    string with ``os.path.join``.
    """
    dataset_root = '../dataset/'
    base_dir = os.path.join(dataset_root, dataset)
    return base_dir

def create_dir_if_absent(path):
    """Create ``path`` (including missing parent directories) unless it exists.

    The directory is created directly and an "already exists" error is
    tolerated, instead of checking for existence first. This avoids the race
    where another process creates the directory between the check and the
    ``os.makedirs`` call. A path that already exists is left untouched.

    Args:
        path: Directory path to ensure.

    Raises:
        OSError: For failures other than the path already existing
            (e.g. permission denied).
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Something is already at `path`; nothing to do.
        pass

def train_downstream_classfier(dataset, upstream_weight_type, batch_size, n_epochs, learning_rate, weight_decay, device):
    """Fine-tune a ResNet-50 classifier on ``dataset`` and checkpoint the best epoch.

    The dataset is read from ``../dataset/<dataset>/Training`` and
    ``../dataset/<dataset>/Test``. After every epoch the model is evaluated on
    the test split; whenever the test accuracy improves, the weights are saved
    to ``downstream_artifacts/<dataset>/classifier_acc_<acc>.pth`` and the
    metrics of that epoch are pickled to ``best_init_dict.pkl`` in the same
    directory.

    Args:
        dataset: Name of the dataset directory under ``../dataset/``.
        upstream_weight_type: One of ``'naive'`` (random initialisation),
            ``'imagenet'``, ``'byol'`` or ``'rotation'``. The last two
            currently fall back to ImageNet weights (see the TODOs below).
        batch_size: Batch size passed to ``create_dataloader``.
        n_epochs: Number of train/evaluate epochs to run.
        learning_rate: Forwarded to ``train`` on every epoch.
        weight_decay: Forwarded to ``train`` on every epoch.
        device: Device the network is moved to and that is forwarded to
            ``train`` / ``test``.

    Raises:
        ValueError: If ``upstream_weight_type`` is not one of the supported
            values.

    Note:
        ``create_dataloader``, ``train`` and ``test`` are not defined in this
        section; presumably they come from ``from utils import *`` — confirm.
    """
    # Fail fast on an unknown weight type; otherwise `net` would be unbound
    # further down and surface as an opaque UnboundLocalError.
    supported_weight_types = ('naive', 'imagenet', 'byol', 'rotation')
    if upstream_weight_type not in supported_weight_types:
        raise ValueError(
            'Unsupported upstream_weight_type %r; expected one of %s'
            % (upstream_weight_type, ', '.join(supported_weight_types))
        )

    dataset_path = create_base_dir(dataset)
    train_path = os.path.join(dataset_path, 'Training')
    test_path = os.path.join(dataset_path, 'Test')

    train_loader, _train_data = create_dataloader(train_path, batch_size, True)
    test_loader, _test_data = create_dataloader(test_path, batch_size, False)

    # Number of target classes = number of entries in the test directory.
    # (Previously referenced the undefined name `init_test_path`.)
    target_class_num = len(os.listdir(test_path))
    print('target_class_num: ', target_class_num)

    print(_test_data.class_to_idx)

    if upstream_weight_type == 'naive':
        net = models.resnet50(pretrained=False)

    elif upstream_weight_type == 'imagenet':
        net = models.resnet50(pretrained=True)

    elif upstream_weight_type == 'byol':
        net = models.resnet50(pretrained=True)
        # TODO; import weight

    elif upstream_weight_type == 'rotation':
        net = models.resnet50(pretrained=True)
        # TODO; import weight

    # Replace the classification head with one sized for the target classes.
    # The nn.Sequential wrapper is kept so state-dict keys ("fc.0.*") stay
    # compatible with checkpoints saved earlier.
    net.fc = nn.Sequential(
        nn.Linear(
            net.fc.in_features,
            target_class_num
        ))

    net.to(device)

    best_model_weight = copy.deepcopy(net.state_dict())
    best_test_acc = 0
    best_init_dict = {}

    model_save_base = os.path.join('downstream_artifacts', dataset)
    create_dir_if_absent(model_save_base)

    for epoch in range(n_epochs):

        net, train_acc, train_prec, train_rec, train_f1 = train(train_loader, net, learning_rate, weight_decay, device)
        net, test_acc, test_prec, test_rec, test_f1 = test(test_loader, net, device)

        # Checkpoint only when the test accuracy strictly improves.
        if test_acc > best_test_acc:

            best_init_dict['acc'] = test_acc
            best_init_dict['prec'] = test_prec
            best_init_dict['rec'] = test_rec
            best_init_dict['f1'] = test_f1

            best_test_acc = test_acc
            test_acc_str = '%.5f' % test_acc

            print('[Notification] Best Model Updated!')
            best_model_weight = copy.deepcopy(net.state_dict())

            # Each improvement is written to its own file, named by accuracy.
            model_save_path = os.path.join(model_save_base, 'classifier_acc_' + test_acc_str + '.pth')
            torch.save(net.state_dict(), model_save_path)

            # The metrics file is overwritten so it always holds the best epoch.
            with open(os.path.join(model_save_base, 'best_init_dict.pkl'), 'wb') as f:
                pickle.dump(best_init_dict, f)