In [1]:
import collections
import copy
import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from absl import app, flags
from skimage import io
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

task_type = 'training'
experiment_name = 'exp'
label_type = 'category'
learning_rate = 1e-4
weight_decay = 0
batch_size  = 256
epochs = 5
LABEL_SIZE = {'domain': 4, 'category': 7}

class PACSDataset(Dataset):

  def __init__(self,
               root_dir,
               label_type='domain',
               is_training=False,
               transform=None):
    self.root_dir = os.path.join(root_dir, 'train' if is_training else 'val')
    self.label_type = label_type
    self.is_training = is_training
    if transform:
      self.transform = transform
    else:
      self.transform = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize(mean=[0.7659, 0.7463, 0.7173],
                               std=[0.3089, 0.3181, 0.3470]),
      ])

    self.dataset, self.label_list = self.initialize_dataset()
    self.label_to_id = {x: i for i, x in enumerate(self.label_list)}
    self.id_to_label = {i: x for i, x in enumerate(self.label_list)}

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    image, label = self.dataset[idx]
    label_id = self.label_to_id[label]
    image = self.transform(image)
    return image, label_id

  def initialize_dataset(self):
    assert os.path.isdir(self.root_dir), \
        '`root_dir` is not found at %s' % self.root_dir

    dataset = []
    domain_set = set()
    category_set = set()
    cnt = 0

    for root, dirs, files in os.walk(self.root_dir, topdown=True):
      if files:
        parts = root.split(os.sep)
        if len(parts) >= 3:
          domain = parts[-2]
          category = parts[-1]
          domain_set.add(domain)
          category_set.add(category)
          pbar = tqdm(files)
          for name in pbar:
            pbar.set_description('Processing Folder: domain=%s, category=%s' %
                                (domain, category))
            img_array = io.imread(os.path.join(root, name))
            dataset.append((img_array, domain, category))
    # use this below if you are running on google colab
    # for root, dirs, files in os.walk(self.root_dir, topdown=True):
    #   if files:
    #     _, domain, category = root.rsplit('/', maxsplit=2)
    #     domain_set.add(domain)
    #     category_set.add(category)
    #     pbar = tqdm(files)
    #     for name in pbar:
    #       pbar.set_description('Processing Folder: domain=%s, category=%s' %
    #                            (domain, category))
    #       img_array = io.imread(os.path.join(root, name))
    #       dataset.append((img_array, domain, category))
    images, domains, categories = zip(*dataset)

    if self.label_type == 'domain':
      labels = sorted(domain_set)
      dataset = list(zip(images, domains))
    elif self.label_type == 'category':
      labels = sorted(category_set)
      dataset = list(zip(images, categories))
    else:
      raise ValueError(
          'Unknown `label_type`: Expecting `domain` or `category`.')

    return dataset, labels


class AlexNetLargeKernel(nn.Module):

  def __init__(self, configs):
    super().__init__()
    dropout = configs['dropout']
    num_classes = configs['num_classes']

    self.features = nn.Sequential(
      nn.Conv2d(3, 96, kernel_size=21, stride=8, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(96, 256, kernel_size=7, stride=2,padding=2),
      nn.ReLU(inplace=True),
      nn.Conv2d(256, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 256, kernel_size=3, stride=2),
      nn.ReLU(inplace=True),
    )
    self.classifier = nn.Sequential(
      nn.Flatten(),
      nn.Dropout(dropout),
      nn.Linear(9216, 4096),
      nn.ReLU(inplace=True),
      nn.Dropout(dropout),
      nn.Linear(4096, 4096),
      nn.ReLU(inplace=True),
      nn.Linear(4096, num_classes)
    )

  def forward(self, x):
    x = self.features(x)
    x = self.classifier(x)
    return x

def model_training():

  best_model = None
  best_acc = 0.0

  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  expt_name = 'experiments/{}/{}_lr_{}.wd_{}'.format(experiment_name, label_type, learning_rate, weight_decay)

  os.makedirs(expt_name, exist_ok=True)
  writer = SummaryWriter(log_dir=expt_name)

  configs = {'num_classes': LABEL_SIZE[label_type], 'dropout': 0.5}

  model = AlexNetLargeKernel(configs).to(device)

  print('Model Architecture:\n%s' % model)

  criterion = nn.CrossEntropyLoss(reduction='mean')
  optimizer = torch.optim.Adam(model.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)

  try:
    for epoch in range(epochs):
      print(epoch)
      for phase in ('train', 'eval'):
        if phase == 'train':
          model.train()
          dataset = train_dataset
          data_loader = train_loader
        else:
          model.eval()
          dataset = val_dataset
          data_loader = val_loader

        running_loss = 0.0
        running_corrects = 0

        for step, (images, labels) in enumerate(data_loader):
          print(step)
          images = images.to(device)
          labels = labels.to(device)

          optimizer.zero_grad()

          with torch.set_grad_enabled(phase == 'train'):
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if phase == 'train':
              loss.backward()
              optimizer.step()

              writer.add_scalar('Loss/{}'.format(phase), loss.item(),
                                epoch * len(data_loader) + step)

          running_loss += loss.item() * images.size(0)
          running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(dataset)
        epoch_acc = running_corrects.double() / len(dataset)
        writer.add_scalar('Epoch_Loss/{}'.format(phase), epoch_loss, epoch)
        writer.add_scalar('Epoch_Accuracy/{}'.format(phase), epoch_acc, epoch)
        print('[Epoch %d] %s accuracy: %.4f, loss: %.4f' %
              (epoch + 1, phase, epoch_acc, epoch_loss))

        if phase == 'eval':
          if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model = copy.deepcopy(model.state_dict())
            torch.save(best_model, os.path.join(expt_name, 'best_model.pt'))

  except KeyboardInterrupt:
    pass

  return




In [2]:
  train_dataset = PACSDataset(root_dir='pacs_dataset',
                              label_type=label_type,
                              is_training=True)
  train_loader = DataLoader(train_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=0)

  val_dataset = PACSDataset(root_dir='pacs_dataset',
                            label_type=label_type,
                            is_training=False)
  val_loader = DataLoader(val_dataset,
                          batch_size=batch_size,
                          shuffle=False,
                          num_workers=0)

Processing Folder: domain=art_painting, category=dog:   0%|          | 0/348 [00:00<?, ?it/s]

Processing Folder: domain=art_painting, category=dog: 100%|██████████| 348/348 [00:03<00:00, 98.61it/s] 
Processing Folder: domain=art_painting, category=elephant: 100%|██████████| 227/227 [00:02<00:00, 100.82it/s]
Processing Folder: domain=art_painting, category=giraffe: 100%|██████████| 254/254 [00:02<00:00, 98.99it/s] 
Processing Folder: domain=art_painting, category=guitar: 100%|██████████| 169/169 [00:01<00:00, 98.01it/s] 
Processing Folder: domain=art_painting, category=horse: 100%|██████████| 179/179 [00:01<00:00, 94.67it/s] 
Processing Folder: domain=art_painting, category=house: 100%|██████████| 262/262 [00:02<00:00, 95.19it/s] 
Processing Folder: domain=art_painting, category=person: 100%|██████████| 404/404 [00:04<00:00, 94.33it/s] 
Processing Folder: domain=cartoon, category=dog: 100%|██████████| 343/343 [00:03<00:00, 98.90it/s] 
Processing Folder: domain=cartoon, category=elephant: 100%|██████████| 411/411 [00:04<00:00, 96.24it/s] 
Processing Folder: domain=cartoon, catego

In [3]:
model_training()

Model Architecture:
AlexNetLargeKernel(
  (features): Sequential(
    (0): Conv2d(3, 96, kernel_size=(21, 21), stride=(8, 8), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(96, 256, kernel_size=(7, 7), stride=(2, 2), padding=(2, 2))
    (3): ReLU(inplace=True)
    (4): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(2, 2))
    (9): ReLU(inplace=True)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Dropout(p=0.5, inplace=False)
    (2): Linear(in_features=9216, out_features=4096, bias=True)
    (3): ReLU(inplace=True)
    (4): Dropout(p=0.5, inplace=False)
    (5): Linear(in_features=4096, out_features=4096, bias=True)
    (6): ReLU(inplace=True)
    (7): Linear(in_features=4096, out_features=7, bias=True)
  )
)
0
0
1
2
3
4
5
6
7
8
9
10

Category_Label:
[Epoch 5] train accuracy: 0.5625, loss: 1.1556
[Epoch 5] eval accuracy: 0.5778, loss: 1.1296

Domain_Label:
[Epoch 5] train accuracy: 0.8475, loss: 0.3743
[Epoch 5] eval accuracy: 0.8133, loss: 0.4364