In [27]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
#import torch.autograd.variable as Variable
from torchvision import transforms

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

import os
import argparse
import numpy as np

#SimCLR
from simclr import SimCLR
from simclr.modules import LogisticRegression, get_resnet
from simclr.modules.transformations import TransformsSimCLR

#ReLIC
#[TODO]
from relic import ReLIC
from relic.modules.transformations import TransformsRelic

# TensorBoard
#from torch.utils.tensorboard import SummaryWriter

from model import load_optimizer, save_model
from utils import yaml_config_hook

In [28]:
def _str2bool(value):
    """Parse a command-line string into a bool for argparse.

    ``type=bool`` is unusable with argparse because ``bool("False")`` is True;
    this accepts the common spellings of true/false and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _arg_type(default):
    """Return the argparse ``type`` converter for a config default value.

    Booleans get ``_str2bool``; ``None`` defaults fall back to ``str`` because
    ``type(None)`` cannot convert a string; every other value keeps
    ``type(default)``.
    """
    if isinstance(default, bool):
        return _str2bool
    if default is None:
        return str
    return type(default)


# Build CLI arguments from the YAML config: every key returned by
# yaml_config_hook becomes --<key>, with the YAML value as its default.
parser = argparse.ArgumentParser(description="SimCLR/ReLIC")
config = yaml_config_hook("./config/config.yaml")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=_arg_type(v))

# Parse an empty list instead of sys.argv (which belongs to the Jupyter
# kernel), so only the YAML defaults are used here.
args = parser.parse_args(args=[])


# Master address/port for distributed data parallel (read by torch.distributed's
# env:// initialisation).
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8000"

# Model/checkpoint directory; exist_ok keeps this cell safe to re-run.
os.makedirs(args.model_path, exist_ok=True)

args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args.num_gpus = torch.cuda.device_count()
# NOTE(review): assumes the config defines `gpus` (per node) and `nodes` — TODO confirm.
args.world_size = args.gpus * args.nodes

In [29]:
# PACS dataset: four domains (photo, art painting, cartoon, sketch) that share
# the same seven classes.
NUM_CLASSES = 7      # 7 classes for each domain: 'dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person'
DATASETS_NAMES = ['photo', 'art', 'cartoon', 'sketch']
CLASSES_NAMES = ['Dog', 'Elephant', 'Giraffe', 'Guitar', 'Horse', 'House', 'Person']
DIR_PHOTO = '../PRISM/datasets/PACS/photo'
DIR_ART = '../PRISM/datasets/PACS/art_painting'
DIR_CARTOON = '../PRISM/datasets/PACS/cartoon'
DIR_SKETCH = '../PRISM/datasets/PACS/sketch'

# Domain name -> image root directory (ImageFolder expects one sub-folder per class).
pacs_convertor = {'default': DIR_PHOTO, 'photo': DIR_PHOTO, 'art': DIR_ART,
                  'cartoon': DIR_CARTOON, 'sketch': DIR_SKETCH}

# Keys into pacs_convertor for the training and test sets.
# NOTE(review): both resolve to the same directory, so the "test" set is the
# training set itself. For a domain-generalisation evaluation TEST_DOMAIN
# should presumably be a held-out domain (e.g. 'art') — TODO confirm intent.
TRAIN_DOMAIN = 'default'
TEST_DOMAIN = 'default'
BATCH_SIZE = 64

# ImageFolder's default loader converts every image to 3-channel RGB, so
# normalise with per-channel statistics. (0.1307,)/(0.3081,) are the
# single-channel MNIST statistics and do not describe RGB photos; the ImageNet
# statistics below are the usual choice for ResNet-style backbones.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_dataset = torchvision.datasets.ImageFolder(pacs_convertor[TRAIN_DOMAIN], transform=transform)
test_dataset = torchvision.datasets.ImageFolder(pacs_convertor[TEST_DOMAIN], transform=transform)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,   # every training batch has exactly BATCH_SIZE samples
    num_workers=args.workers,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    # NOTE(review): drop_last=True discards the final partial batch, so up to
    # BATCH_SIZE - 1 test images are never evaluated. Kept in case a downstream
    # consumer needs fixed-size batches; set False for plain evaluation.
    drop_last=True,
    num_workers=args.workers,
)

In [122]:
import Augmentor


In [123]:
# Augmentor pipeline over the PACS photo domain. The name `img` is kept because
# the following cells use it; it is a Pipeline object, not an image.
# Augmented samples are written to a "photo_augmented" directory alongside
# DIR_PHOTO (resolved against the current working directory, like DIR_PHOTO
# itself) rather than to a hard-coded /home/<user>/... path.
AUGMENTED_DIR = os.path.abspath(os.path.normpath(DIR_PHOTO) + '_augmented')
img = Augmentor.Pipeline(DIR_PHOTO, output_directory=AUGMENTED_DIR)

Initialised with 1670 image(s) found.
Output directory set to /home/dongkyu/PRISM/datasets/PACS/photo_augmented.

In [124]:
# Register the augmentation operations (each with probability=1.0, i.e. applied
# to every generated sample). Every call appends to the pipeline's operation
# list, so start from an empty list; otherwise re-running this cell would
# stack a second copy of every transform.
img.operations.clear()
img.random_distortion(probability=1.0, grid_width=8, grid_height=8, magnitude=2)
img.skew(probability=1.0, magnitude=0.2)
img.random_color(probability=1.0, min_factor=0.9, max_factor=1.0)
# Fixed-size random crop (randomise_percentage_area=False); the sampled images
# reported in the next cell's output are 204x204.
img.crop_random(probability=1.0, percentage_area=0.9, randomise_percentage_area=False)

In [125]:
# Generate AUGMENTATION_FACTOR augmented samples per source image found by the
# pipeline (10 x 1670 photo images = 16700 samples). Deriving the count from
# the pipeline keeps it right if the source directory changes.
# NOTE: the per-sample progress output can exceed Jupyter's IOPub message rate
# limit, as in the output below.
AUGMENTATION_FACTOR = 10
img.sample(AUGMENTATION_FACTOR * len(img.augmentor_images))

Processing <PIL.Image.Image image mode=RGB size=204x204 at 0x7FBB12466040>:  87%|████████▋ | 14483/16700 [01:15<00:00, 2744.02 Samples/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [126]:
# Intentionally empty: this cell only held a throwaway `1+1` expression that
# did not read or modify any notebook state.

2