In [None]:
import pandas as pd
import numpy as np
import torchvision.models as models
import torch.nn as nn
import os
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image

## Load images

In [None]:
# Default input-image root and feature output directory (the extraction cell
# further down reassigns both for the run it performs).
img_path, feature_path = './', '../features/'

In [None]:
# All image domains (presumably the six DomainNet domains); each has
# "<domain>_train.txt" / "<domain>_test.txt" list files in the working directory.
domains = 'clipart infograph painting quickdraw real sketch'.split()
classes = list()

In [None]:
# Sanity check: parse the first entry of the first domain's test list and open
# that image. Each list line has the form "<image path> <label>".
with open(domains[0] + '_test.txt', 'r') as file:
    first_line = file.readline()
# Print the same line that is parsed below (a second readline() would show the
# next entry instead).
print(first_line.split())
img_name, label = first_line.split()
img = Image.open(img_name).convert('RGB')

## Extract ResNet features

In [None]:
# Backbone names accepted by res_feature (torchvision.models constructors).
_SUPPORTED_RESNETS = ('resnet50', 'resnet101', 'resnet152')

# Cache of headless backbones keyed by (model name, use_gpu) so pretrained
# weights are loaded once per kernel session instead of once per image.
_FEATURE_EXTRACTORS = {}


def _get_feature_extractor(pretrained_model, use_gpu):
    """Build (once) and return the pretrained ResNet without its final fc layer, in eval mode."""
    key = (pretrained_model, bool(use_gpu))
    extractor = _FEATURE_EXTRACTORS.get(key)
    if extractor is None:
        model = getattr(models, pretrained_model)(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False
        # Drop the last child (the fc classifier); in torchvision's ResNet the
        # remaining stack ends with the global average pool.
        extractor = nn.Sequential(*list(model.children())[:-1])
        extractor.eval()
        if use_gpu:
            extractor = extractor.cuda()
        _FEATURE_EXTRACTORS[key] = extractor
    return extractor


def res_feature(img, pretrained_model='resnet50', feature_layer='avgpool', use_gpu=False):
    """Extract a flattened pretrained-ResNet feature vector for one PIL image.

    The image is resized to 224x224, converted to a tensor, normalized with the
    ImageNet mean/std and run through the ResNet with its fc layer removed.

    Args:
        img: PIL image (callers in this notebook pass ``.convert('RGB')`` output).
        pretrained_model: one of 'resnet50', 'resnet101', 'resnet152'.
        feature_layer: currently ignored -- features always come from the layer
            just before ``fc``. Kept for interface compatibility.
        use_gpu: run the model on CUDA when True.

    Returns:
        A numpy array of shape (1, D), or None (after printing a message) when
        ``img`` is None or ``pretrained_model`` is not supported.
    """
    if img is None:
        print("No input image!")
        return None
    if pretrained_model not in _SUPPORTED_RESNETS:
        print("None model input")
        return None

    extractor = _get_feature_extractor(pretrained_model, use_gpu)

    # pre-processing image
    scaler = transforms.Resize((224, 224))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    to_tensor = transforms.ToTensor()
    img_norm = normalize(to_tensor(scaler(img))).unsqueeze(0)  # add batch dim
    if use_gpu:
        img_norm = img_norm.cuda()

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        feature = extractor(img_norm)
    return feature.cpu().numpy().reshape((1, -1))

In [None]:
model_type = 'resnet50'
img_path = './'
feature_path = '../features/' + model_type + '/'
# Domains processed in this run ('clipart', 'infograph', 'quickdraw', 'sketch'
# are excluded here).
domains = ['painting', 'real']
splits = ['train', 'test']
use_gpu = torch.cuda.is_available()


def extract_split_features(list_file, model_type, use_gpu, img_root='./'):
    """Extract ResNet features for every "<image path> <label>" line of a list file.

    Blank lines are skipped. Lines that do not split into exactly two fields,
    whose label is not an integer, or whose image cannot be opened or processed
    are reported and skipped, so one bad entry does not abort the whole split.

    Returns:
        (feats, labels): feats is a float64 array of shape (N, 2048) and labels
        an int array of shape (N, 1), N being the number of successful lines.
    """
    feat_rows = []
    label_rows = []
    with open(list_file, 'r') as file:
        for line_no, line in enumerate(file, start=1):
            if not line.strip():
                print("Empty: ", line_no)
                continue
            try:
                image_file, label = line.split()
                label_value = int(label)
                if line_no % 1000 == 0:
                    print('Extract class: ', label, ' Line: ', line_no)
                with Image.open(os.path.join(img_root, image_file)) as raw_img:
                    img = raw_img.convert('RGB')
                feature = res_feature(img, pretrained_model=model_type, use_gpu=use_gpu)
                if feature is None:
                    raise ValueError('no feature extracted')
            except Exception as exc:  # best effort: report this entry and move on
                print("!!!!!!!!!!!!!!!! Exception !!!!!!!!!!!!!!!!   Line", line_no,
                      repr(line.strip()), '-', repr(exc))
                continue
            feat_rows.append(feature)
            label_rows.append(label_value)
    # Each feature is (1, 2048); the reshape also yields (0, 2048) when nothing succeeded.
    feats = np.asarray(feat_rows, dtype=np.float64).reshape((-1, 2048))
    labels = np.asarray(label_rows, dtype=int).reshape((-1, 1))
    return feats, labels


# to_csv does not create missing directories.
os.makedirs(feature_path, exist_ok=True)
for domain in domains:
    for s in splits:
        print('Processing domain & set: ', domain, s)
        feats, labels = extract_split_features(
            domain + '_' + s + '.txt', model_type, use_gpu, img_root=img_path)
        # save to csv file
        out_prefix = feature_path + domain + '_' + s + '_' + model_type
        output_feats_file = out_prefix + '_feats.csv'
        output_labels_file = out_prefix + '_labels.csv'
        print('Saving to ', output_feats_file, '...')
        res_feats = pd.DataFrame(feats)
        res_labels = pd.DataFrame(labels)
        res_feats.to_csv(output_feats_file, index=False, header=False)
        res_labels.to_csv(output_labels_file, index=False, header=False)

In [None]:
# Re-save the most recently processed split. This cell reads `feats`, `labels`,
# `domain`, `s`, `model_type` and `feature_path` from the kernel state left by
# the extraction loop, so it only works after that cell has run.
out_prefix = feature_path + domain + '_' + s + '_' + model_type
output_feats_file = out_prefix + '_feats.csv'
output_labels_file = out_prefix + '_labels.csv'
print('Saving to ', output_feats_file, '...')
# to_csv does not create missing directories.
os.makedirs(feature_path, exist_ok=True)
res_feats = pd.DataFrame(feats)
res_labels = pd.DataFrame(labels)
res_feats.to_csv(output_feats_file, index=False, header=False)
res_labels.to_csv(output_labels_file, index=False, header=False)