In [1]:
from __future__ import print_function, division

# --- Standard library ---
import copy
import os
import pickle
import time

# --- Third-party (relative order preserved from the original cell) ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
from torch.utils.data import Subset, DataLoader

import torchy  # NOTE(review): not referenced in the visible cells — confirm it is still needed

import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import pandas as pd

# Let cuDNN benchmark and cache the fastest convolution algorithms
# (inputs are a fixed 224x224 after the CenterCrop below).
cudnn.benchmark = True

In [2]:
# ImageNet per-channel means and standard deviations: the backbone is pretrained on
# ImageNet, so inputs are normalised with the same statistics.
means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# Preprocessing applied to every image.
# NOTE(review): there is no Resize step, so this assumes every image is at least
# 224 px on each side — confirm for the PACS images on disk.
transf = transforms.Compose([
    transforms.CenterCrop(224),        # central 224x224 patch, matching ImageNet pretraining resolution
    transforms.ToTensor(),             # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    transforms.Normalize(means, stds)  # per-channel standardisation with the ImageNet stats above
])

data_root = 'Data/PACS/'

# One ImageFolder per PACS domain sub-directory, keyed by directory name.
# NOTE: this dict shadows the `torchvision.datasets` module imported at the top,
# which is why ImageFolder is referenced fully-qualified below.
datasets = {}
for name in sorted(os.listdir(data_root)):  # sorted: listdir order is platform-dependent
    domain_dir = os.path.join(data_root, name)
    # Skip hidden entries (e.g. .DS_Store) and anything that is not a directory.
    if name.startswith('.') or not os.path.isdir(domain_dir):
        continue
    datasets[name] = torchvision.datasets.ImageFolder(domain_dir, transform=transf)
    print(f"Added :{name}, length: {len(datasets[name])}, classes: {len(datasets[name].classes)}")

Added :cartoon, length: 2344, classes: 7
Added :art_painting, length: 2048, classes: 7
Added :photo, length: 1670, classes: 7
Added :sketch, length: 3929, classes: 7


In [3]:
BATCH_SIZE = 128

# One DataLoader per domain, used only for feature extraction:
#  - shuffle=False keeps extraction deterministic, so each class's feature list
#    follows the dataset's own ordering;
#  - drop_last=False keeps the final partial batch, otherwise up to
#    BATCH_SIZE - 1 images per domain would silently get no features.
dataloaders = {
    name: DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, drop_last=False)
    for name, dataset in datasets.items()
}



In [5]:
# Pretrained ResNet-18 turned into a frozen feature extractor: all children except
# the final fully-connected classifier are kept, so the network outputs the pooled
# feature map (num_ftrs channels) instead of class logits.
backbone = models.resnet18(pretrained=True)
num_ftrs = backbone.fc.in_features  # number of feature channels produced by `resnet18` below
resnet18 = nn.Sequential(*list(backbone.children())[:-1])
for p in resnet18.parameters():
    p.requires_grad = False  # frozen: no gradients are needed for feature extraction

# Inference mode: BatchNorm must use its stored pretrained running statistics
# (and must not update them) rather than per-batch statistics.
resnet18.eval()
    


In [None]:
# Use the first CUDA GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalised image tensor with matplotlib.

    Args:
        inp: image tensor laid out as (C, H, W), normalised with ``mean``/``std``.
        title: optional title shown above the image.
        mean: per-channel means used for normalisation (defaults: ImageNet values,
            the same as ``means`` in the preprocessing cell).
        std: per-channel standard deviations used for normalisation (defaults:
            ImageNet values, the same as ``stds`` in the preprocessing cell).
    """
    # detach().cpu() so tensors that require grad or live on the GPU can be shown;
    # transpose (C, H, W) -> (H, W, C), the layout plt.imshow expects.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.asarray(std) * img + np.asarray(mean)  # undo Normalize
    img = np.clip(img, 0, 1)  # rounding can push values slightly outside [0, 1]
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated



In [None]:
# Class names of the cartoon domain; the position of each name is the integer
# label ImageFolder assigns to that class.
cartoon_dataset = datasets['cartoon']
print(cartoon_dataset.classes)

In [None]:
def extract_features(model, dataloaders, datasets):
    """Run ``model`` over every domain's dataloader and group the outputs by class label.

    Args:
        model: ``torch.nn.Module`` used as a frozen feature extractor. It is run in
            eval mode; its previous train/eval mode is restored afterwards.
        dataloaders: mapping domain name -> iterable of ``(inputs, labels)`` batches.
        datasets: mapping domain name -> dataset exposing a ``classes`` list; its
            keys decide which domains are processed.

    Returns:
        ``{domain: {class_index: [feature_tensor, ...]}}`` where each feature tensor
        is the model output for one sample, on the CPU, owning its own storage.
    """
    # BatchNorm/Dropout must run in inference mode; otherwise features depend on the
    # batch composition and BatchNorm's running statistics get overwritten.
    was_training = model.training
    model.eval()
    # Send inputs to wherever the model's weights live (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    feature_dict = {}
    try:
        with torch.no_grad():
            for domain, dataset in datasets.items():
                print(domain)
                feature_dict[domain] = {i: [] for i in range(len(dataset.classes))}

                for j, (inputs, labels) in enumerate(dataloaders[domain]):
                    print(f"Batch: {j}")
                    features = model(inputs.to(device)).cpu()
                    for feature, label in zip(features, labels):
                        # clone(): a row of `features` is a view of the whole batch;
                        # without it each stored sample keeps (and serialises) the
                        # full batch storage.
                        feature_dict[domain][int(label)].append(feature.clone())
    finally:
        model.train(was_training)

    return feature_dict

    
    
    
# Extract frozen ResNet-18 features for every PACS domain, grouped by class index.
feature_dict = extract_features(model=resnet18, dataloaders=dataloaders, datasets=datasets)
        

    

In [None]:
# Persist the extracted features. `with` guarantees the file is closed even if
# pickle.dump raises. NOTE: loading a pickle can execute arbitrary code, so only
# load PACS.pkl from a trusted source.
with open("PACS.pkl", "wb") as a_file:
    pickle.dump(feature_dict, a_file)

# NOTE(review): `tmp` is not used in the visible cells. It wraps the nested
# {domain: {class_index: [tensor, ...]}} dict in a DataFrame (domains as columns,
# class indices as rows); consider removing it or giving it a descriptive name.
tmp = pd.DataFrame.from_dict(feature_dict)


In [None]:
# Number of photo-domain feature tensors stored under class index 2.
photo_class2_count = len(feature_dict['photo'][2])
print(photo_class2_count)

