In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset, random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST, USPS, SVHN, CIFAR10, STL10

import tensorflow_datasets as tfds

import os
import pickle
import numpy as np
from scipy.io import loadmat
import PIL
from PIL import Image

from tools.autoaugment import SVHNPolicy, CIFAR10Policy
from tools.randaugment import RandAugment

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Home directory used as the root for torchvision dataset downloads.
# expanduser('~') honours $HOME when it is set and still resolves on
# platforms where that variable is missing (os.environ['HOME'] raises KeyError there).
HOME = os.path.expanduser('~')

In [3]:
def load_mnist(split='train', translate=None, twox=False, ntr=None, autoaug=None, channels=3):
    '''
    Load MNIST as float images scaled by 1/255, optionally with augmentation.

    The processed tensors are cached in ``data/mnist-{split}.pkl``. The cache is
    built on first use (MNIST is downloaded under ``HOME/.pytorch/MNIST``) and
    simply re-read afterwards.

    Args:
        split: 'train' selects the training set (only the first 10,000 samples
            are kept when the cache is built); any other value selects the test set.
        translate: if not None, passed as ``RandomAffine(0, [translate, translate])``.
        twox: forwarded to ``myTensorDataset``; only used on the augmented path.
        ntr: if not None, keep only the first ``ntr`` samples.
        autoaug: 'AA' appends ``SVHNPolicy()`` (AutoAugment), 'RA' appends
            ``RandAugment(3, 4)``. Any other non-None value -- including
            'FastAA', which is not implemented here -- adds no policy but still
            goes through the ToPILImage -> ToTensor pipeline.
        channels: 1 keeps a single channel; otherwise the cached 3-channel
            (grey replicated three times) tensor is returned.

    Returns:
        ``TensorDataset(x, y)`` when neither ``translate`` nor ``autoaug`` is
        given, otherwise a ``myTensorDataset`` applying the transform pipeline.
    '''
    path = f'data/mnist-{split}.pkl'
    if not os.path.exists(path):
        dataset = MNIST(f'{HOME}/.pytorch/MNIST', train=(split=='train'), download=True)
        x, y = dataset.data, dataset.targets
        if split=='train':
            #[TODO]- Why should we only use 10k images?
            x, y = x[0:10000], y[0:10000]
        # presumably resize_imgs rescales every image to 32x32 -- defined outside this cell, confirm
        x = torch.tensor(resize_imgs(x.numpy(), 32))
        # scale to float, add a channel axis and replicate it three times
        x = (x.float()/255.).unsqueeze(1).repeat(1,3,1,1)
        # A fresh checkout may not have the cache directory yet.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    # NOTE: pickle executes arbitrary code on load; only read cache files you created.
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:,0:1,:,:]

    if ntr is not None:
        x, y = x[0:ntr], y[0:ntr]

    # Without Data Augmentation
    if (translate is None) and (autoaug is None):
        return TensorDataset(x, y)

    # Data Augmentation Pipeline (operates on PIL images, converts back at the end)
    transform = [transforms.ToPILImage()]
    if translate is not None:
        transform.append(transforms.RandomAffine(0, [translate, translate]))
    if autoaug == 'AA':
        transform.append(SVHNPolicy())
    elif autoaug == 'RA':
        transform.append(RandAugment(3,4))
    transform.append(transforms.ToTensor())
    transform = transforms.Compose(transform)
    return myTensorDataset(x, y, transform=transform, twox=twox)

In [5]:
def load_cifar10(split='train', translate=None, twox=False, ntr=None, autoaug=None, channels=3):
    '''
    Load CIFAR-10 as float images scaled by 1/255, optionally with augmentation.

    The processed tensors are cached in ``data/cifar10-{split}.pkl``; the cache
    is built on first use (CIFAR-10 is downloaded under ``HOME/.pytorch/CIFAR10``).

    Args:
        split: 'train' selects the training set, any other value the test set.
        translate: if not None, passed as ``RandomAffine(0, [translate, translate])``.
        twox: forwarded to ``myTensorDataset``; only used on the augmented path.
        ntr: if not None, keep only the first ``ntr`` samples.
        autoaug: 'AA' appends ``CIFAR10Policy()`` (AutoAugment), 'RA' appends
            ``RandAugment(3, 4)``. Any other non-None value -- including
            'FastAA', which is not implemented here -- adds no policy.
        channels: 1 keeps only channel index 0 (not a grey-scale average);
            otherwise all cached channels are returned.

    Returns:
        ``TensorDataset(x, y)`` when neither ``translate`` nor ``autoaug`` is
        given, otherwise a ``myTensorDataset`` applying the transform pipeline.
    '''
    path = f'data/cifar10-{split}.pkl'
    if not os.path.exists(path):
        # The raw arrays are read from `.data` / `.targets`, so no torchvision
        # `transform` is passed: it would only run inside __getitem__ and never
        # reach these arrays.
        dataset = CIFAR10(f'{HOME}/.pytorch/CIFAR10', train=(split=='train'), download=True)
        x, y = dataset.data, dataset.targets
        x = torch.tensor(x).float()/255.
        x = x.permute(0,3,1,2)  # channels-last -> channels-first
        y = torch.tensor(y)
        # A fresh checkout may not have the cache directory yet.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    # NOTE: pickle executes arbitrary code on load; only read cache files you created.
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:,0:1,:,:]

    if ntr is not None:
        x, y = x[0:ntr], y[0:ntr]

    # Without Data Augmentation
    if (translate is None) and (autoaug is None):
        return TensorDataset(x, y)

    # Data Augmentation Pipeline (operates on PIL images, converts back at the end)
    transform = [transforms.ToPILImage()]
    if translate is not None:
        transform.append(transforms.RandomAffine(0, [translate, translate]))
    if autoaug == 'AA':
        transform.append(CIFAR10Policy()) #originally SVHNPolicy()
    elif autoaug == 'RA':
        transform.append(RandAugment(3,4))
    transform.append(transforms.ToTensor())
    transform = transforms.Compose(transform)
    return myTensorDataset(x, y, transform=transform, twox=twox)




def load_usps(split='train', channels=3):
    '''
    Load USPS as a TensorDataset of float images scaled by 1/255.

    The processed tensors are cached in ``data/usps-{split}.pkl``; the cache is
    built on first use (USPS is downloaded under ``HOME/.pytorch/USPS``).

    Args:
        split: 'train' selects the training set, any other value the test set.
        channels: 1 keeps a single channel; otherwise the cached 3-channel
            (grey replicated three times) tensor is returned.

    Returns:
        ``TensorDataset(x, y)``.
    '''
    path = f'data/usps-{split}.pkl'
    if not os.path.exists(path):
        dataset = USPS(f'{HOME}/.pytorch/USPS', train=(split=='train'), download=True)
        x, y = dataset.data, dataset.targets
        # presumably resize_imgs rescales every image to 32x32 -- defined outside this cell, confirm
        x = torch.tensor(resize_imgs(x, 32))
        # scale to float, add a channel axis and replicate it three times
        x = (x.float()/255.).unsqueeze(1).repeat(1,3,1,1)
        y = torch.tensor(y)
        # A fresh checkout may not have the cache directory yet.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    # NOTE: pickle executes arbitrary code on load; only read cache files you created.
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:,0:1,:,:]
    return TensorDataset(x, y)

In [6]:
def load_cifar10c(split='train', translate=None, twox=False, ntr=None, autoaug=None, channels=3):
    '''
    Load the TFDS ``cifar10_corrupted`` dataset as a TensorDataset.

    Images are scaled by 1/255 and moved to channels-first layout. The result
    is cached in ``data/cifar10c-{split}.pkl`` and re-read on later calls.

    Args:
        split: TFDS split name, also used in the cache file name.
        translate, autoaug: kept for signature parity with ``load_cifar10``.
            Augmentation is not implemented here; passing either one raises
            instead of silently returning ``None`` as the function used to.
        twox: accepted for signature parity; unused.
        ntr: if not None, keep only the first ``ntr`` samples.
        channels: 1 keeps only channel index 0; otherwise all channels.

    Returns:
        ``TensorDataset(x, y)``.

    Raises:
        NotImplementedError: if ``translate`` or ``autoaug`` is not None.
    '''
    if (translate is not None) or (autoaug is not None):
        raise NotImplementedError(
            'load_cifar10c does not implement augmentation; '
            'call it with translate=None and autoaug=None')

    path = f'data/cifar10c-{split}.pkl'
    if not os.path.exists(path):
        # NOTE(review): shuffle_files=True may make the cached sample order depend
        # on this first load -- confirm before relying on `ntr` prefixes.
        dataset = tfds.as_numpy(tfds.load('cifar10_corrupted', split= split, shuffle_files= True, batch_size= -1))
        x, y = dataset['image'], dataset['label']
        x = torch.tensor(x).float()/255.
        x = x.permute(0,3,1,2)  # channels-last -> channels-first
        y = torch.tensor(y)
        # A fresh checkout may not have the cache directory yet.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    # NOTE: pickle executes arbitrary code on load; only read cache files you created.
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:,0:1,:,:]

    if ntr is not None:
        x, y = x[0:ntr], y[0:ntr]

    return TensorDataset(x, y)

In [7]:
def load_svhn(split='train', channels=3):
    '''
    Load an SVHN split as a TensorDataset of float images scaled by 1/255.

    Args:
        split: forwarded to ``torchvision.datasets.SVHN`` (downloaded under
            ``HOME/.pytorch/SVHN`` when missing).
        channels: 1 averages over dim 1 (keeping that dim with size 1);
            any other value returns the images as stored.
            # assumes dim 1 is the colour axis -- TODO confirm

    Returns:
        ``TensorDataset(images, labels)``.
    '''
    svhn = SVHN(f'{HOME}/.pytorch/SVHN', split=split, download=True)
    raw_images, raw_labels = svhn.data, svhn.labels
    images = torch.tensor(raw_images.astype('float32')/255.)
    labels = torch.tensor(raw_labels)
    if channels == 1:
        images = images.mean(1, keepdim=True)
    return TensorDataset(images, labels)

In [8]:
def load_pacs(split='train', translate=None, twox=False, ntr=None, autoaug=None, channels=3):
    '''
    Load one PACS domain as a TensorDataset of 224x224 float images.

    Images are read with ``ImageFolder`` from ``./data/PACS/<domain>``, converted
    with ``ToTensor`` and resized to 224x224, then cached in
    ``data/pacs-{split}.pkl``. Per the original author's note each domain has
    7 classes (dog, elephant, giraffe, guitar, horse, house, person).

    Args:
        split: one of 'train', 'photo', 'art', 'cartoon', 'sketch'
            ('train' is an alias for the photo domain). An unknown name raises
            ``KeyError`` when the cache has to be built.
        translate, autoaug: kept for signature parity with the other loaders.
            Augmentation is not implemented here; passing either one raises
            instead of silently returning ``None`` as the function used to.
        twox: accepted for signature parity; unused.
        ntr: if not None, keep only the first ``ntr`` samples.
        channels: 1 keeps only channel index 0; otherwise all channels.

    Returns:
        ``TensorDataset(x, y)``.

    Raises:
        NotImplementedError: if ``translate`` or ``autoaug`` is not None.
    '''
    if (translate is not None) or (autoaug is not None):
        raise NotImplementedError(
            'load_pacs does not implement augmentation; '
            'call it with translate=None and autoaug=None')

    pacs_convertor = {
        'train': './data/PACS/photo',
        'photo': './data/PACS/photo',
        'art': './data/PACS/art_painting',
        'cartoon': './data/PACS/cartoon',
        'sketch': './data/PACS/sketch',
    }

    path = f'data/pacs-{split}.pkl'
    if not os.path.exists(path):
        pacs_transforms_train = transforms.Compose([transforms.ToTensor(),transforms.Resize((224,224))])
        dataset = torchvision.datasets.ImageFolder(pacs_convertor[split], transform=pacs_transforms_train)
        # One batch the size of the whole dataset materialises every image at once.
        train_loader = torch.utils.data.DataLoader(dataset,batch_size=len(dataset),drop_last=True)
        # The loader already yields tensors; re-wrapping them in torch.tensor()
        # only made an extra copy and triggered a copy-construct warning.
        x, y = next(iter(train_loader))
        # A fresh checkout may not have the cache directory yet.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    # NOTE: pickle executes arbitrary code on load; only read cache files you created.
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:,0:1,:,:]

    if ntr is not None:
        x, y = x[0:ntr], y[0:ntr]

    return TensorDataset(x, y)

In [9]:
def load_mnist_m(split='train', channels=3):
    '''
    Load a pre-built MNIST-M pickle as a TensorDataset.

    Reads ``data/mnist_m-{split}.pkl``, which must already exist and hold an
    ``(images, labels)`` pair of arrays; images are cast to float32 and
    scaled by 1/255.

    Args:
        split: suffix of the pickle file to read.
        channels: 1 averages over dim 1 (keeping that dim with size 1);
            any other value returns the images as stored.

    Returns:
        ``TensorDataset(images, labels)``.
    '''
    cache_file = f'data/mnist_m-{split}.pkl'
    # NOTE: pickle executes arbitrary code on load; only read files you trust.
    with open(cache_file, 'rb') as handle:
        raw_images, raw_labels = pickle.load(handle)
    images = torch.tensor(raw_images.astype('float32')/255.)
    labels = torch.tensor(raw_labels)
    if channels == 1:
        images = images.mean(1, keepdim=True)
    return TensorDataset(images, labels)

In [10]:
def load_cifar10_1(split='train', translate=None, twox=False, ntr=None, autoaug=None, channels=3):
    '''
    Load the TFDS ``cifar10_1`` dataset as a TensorDataset.

    Images are scaled by 1/255 and moved to channels-first layout. The result
    is cached in ``data/cifar10_1-{split}.pkl`` and re-read on later calls.

    Args:
        split: TFDS split name, also used in the cache file name.
        translate, autoaug: kept for signature parity with ``load_cifar10``.
            Augmentation is not implemented here; passing either one raises
            instead of silently returning ``None`` as the function used to.
        twox: accepted for signature parity; unused.
        ntr: if not None, keep only the first ``ntr`` samples.
        channels: 1 keeps only channel index 0; otherwise all channels.

    Returns:
        ``TensorDataset(x, y)``.

    Raises:
        NotImplementedError: if ``translate`` or ``autoaug`` is not None.
    '''
    if (translate is not None) or (autoaug is not None):
        raise NotImplementedError(
            'load_cifar10_1 does not implement augmentation; '
            'call it with translate=None and autoaug=None')

    path = f'data/cifar10_1-{split}.pkl'
    if not os.path.exists(path):
        # NOTE(review): shuffle_files=True may make the cached sample order depend
        # on this first load -- confirm before relying on `ntr` prefixes.
        dataset = tfds.as_numpy(tfds.load('cifar10_1', split= split, shuffle_files= True, batch_size= -1))
        x, y = dataset['image'], dataset['label']
        x = torch.tensor(x).float()/255.
        x = x.permute(0,3,1,2)  # channels-last -> channels-first
        y = torch.tensor(y)
        # A fresh checkout may not have the cache directory yet.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    # NOTE: pickle executes arbitrary code on load; only read cache files you created.
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:,0:1,:,:]

    if ntr is not None:
        x, y = x[0:ntr], y[0:ntr]

    return TensorDataset(x, y)

In [11]:
def load_cifar10c_level(split='test', ctype= 'fog', level= 5, translate=None, twox=False, ntr=None, autoaug=None, channels=3):
    '''
    Load one corruption type / severity of TFDS ``cifar10_corrupted``.

    The TFDS config ``cifar10_corrupted/{ctype}_{level}`` is loaded, scaled by
    1/255, moved to channels-first layout and cached in
    ``data/cifar10c-{ctype}_{level}.pkl``. Note the cache name does not include
    ``split``, so different splits of the same config share one cache file.

    Args:
        split: TFDS split name used when the cache has to be built.
        ctype: corruption name, e.g. 'fog'.
        level: corruption severity appended to the config name.
        translate, autoaug: kept for signature parity with ``load_cifar10``.
            Augmentation is not implemented here; passing either one raises
            instead of silently returning ``None`` as the function used to.
        twox: accepted for signature parity; unused.
        ntr: if not None, keep only the first ``ntr`` samples.
        channels: 1 keeps only channel index 0; otherwise all channels.

    Returns:
        ``TensorDataset(x, y)``.

    Raises:
        NotImplementedError: if ``translate`` or ``autoaug`` is not None.
    '''
    if (translate is not None) or (autoaug is not None):
        raise NotImplementedError(
            'load_cifar10c_level does not implement augmentation; '
            'call it with translate=None and autoaug=None')

    path = f'data/cifar10c-{ctype}_{level}.pkl'
    if not os.path.exists(path):
        # The f-string is already fully interpolated; the former trailing
        # .format(...) call was redundant.
        tfpath = f'cifar10_corrupted/{ctype}_{level}'
        # NOTE(review): shuffle_files=True may make the cached sample order depend
        # on this first load -- confirm before relying on `ntr` prefixes.
        dataset = tfds.as_numpy(tfds.load(tfpath, split= split, shuffle_files= True, batch_size= -1))
        x, y = dataset['image'], dataset['label']
        x = torch.tensor(x).float()/255.
        x = x.permute(0,3,1,2)  # channels-last -> channels-first
        y = torch.tensor(y)
        # A fresh checkout may not have the cache directory yet.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    # NOTE: pickle executes arbitrary code on load; only read cache files you created.
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:,0:1,:,:]

    if ntr is not None:
        x, y = x[0:ntr], y[0:ntr]

    return TensorDataset(x, y)

In [14]:
# Smoke test: build (or re-read from cache) the CIFAR-10 training split without augmentation.
dataset= load_cifar10('train')

In [34]:
# Draw a single shuffled mini-batch of 128 samples to confirm the dataset iterates.
train_loader = torch.utils.data.DataLoader(dataset,batch_size=128,drop_last=True, shuffle=True)
x, y= next(iter(train_loader))

In [15]:
# Sample count of whatever `dataset` currently holds in the kernel.
len(dataset)

50000

# CIFAR10c

In [32]:
# Load the 'fog_5' config of cifar10_corrupted directly from TFDS;
# batch_size=-1 asks TFDS for the whole split in one batch.
ds= tfds.load('cifar10_corrupted/fog_5', split= 'test', shuffle_files= True, batch_size= -1)

In [14]:
# Corruption names iterated by the CIFAR-10-C checks below; each is combined
# with a severity level to form a TFDS config name.
corruption_type= ['fog','snow','frost','zoom_blur','defocus_blur','frosted_glass_blur','speckle_noise',
                      'shot_noise','impulse_noise','jpeg_compression','pixelate','spatter']

In [23]:
# Load every corruption at severity 5 straight from TFDS and print how many
# test images each config contains.
for c_type in corruption_type:
    st= 'cifar10_corrupted/{c_type}_{lv}'.format(c_type= c_type, lv=5)
    ds= tfds.load(st, split= 'test', shuffle_files= True, batch_size= -1)
    print(len(ds['image']))

10000
10000
10000
10000
10000
10000
10000
10000
10000
10000
10000
10000


In [16]:
# Build/read the cached tensors for every corruption at severity 5 and print
# the first pixel row of the first channel of the first image as a spot check.
# NOTE(review): the recorded output is an AttributeError, so this cell did not
# complete in the saved run -- re-run on a fresh kernel before trusting it.
# `df` here is a TensorDataset, not a DataFrame.
for c_type in corruption_type:
    df= load_cifar10c_level(split='test', ctype= c_type, level= 5)
    print(df[0][0][0][0])
    print("===")

AttributeError: 'tuple' object has no attribute 'shape'

# MNIST-M

In [11]:
# Smoke test: requires data/mnist_m-test.pkl to exist already (the loader never builds it).
dataset= load_mnist_m('test')

# SVHN

In [None]:
# Smoke test: downloads the SVHN test split if missing, then loads it.
dataset= load_svhn('test')

Using downloaded and verified file: /home/dongkyu/.pytorch/SVHN/test_32x32.mat


# PACS

In [15]:
# Load the PACS photo domain and pull the entire dataset as one shuffled batch
# so the labels can be inspected below.
dataset= load_pacs(split='photo')
train_size= len(dataset)
train_loader = torch.utils.data.DataLoader(dataset,batch_size=train_size,drop_last=True, shuffle=True)

x, y= next(iter(train_loader))

In [17]:
# Sample count of whatever `dataset` currently holds in the kernel.
len(dataset)

1670

In [145]:
# Per-class label histogram of the current batch.
# NOTE(review): `y` is overwritten with a list, so re-running this cell fails
# (lists have no .tolist()); the import also belongs in the top import cell.
# The recorded counts sum to 3929 while len(dataset) printed 1670 above, so
# this output presumably comes from a different run -- verify.
y= y.tolist()
from collections import Counter
c=Counter(y)
print(c)

Counter({4: 816, 0: 772, 2: 753, 1: 740, 3: 608, 6: 160, 5: 80})


In [146]:
# Take the first image/label of the current batch for visual inspection.
sample,samplelabel= x[0],y[0]

In [147]:
# Display the label of the picked sample.
samplelabel

5

In [148]:
# Convert the picked tensor back to a PIL image and write it to disk for a visual check.
topil= transforms.ToPILImage()
image= topil(sample)
image.save('./data/image_test.png')

# Check STL10

In [189]:
def load_stl10(split='train', channels=3, size=None):
    '''
    Load an STL10 split as a TensorDataset of float images scaled by 1/255.

    Args:
        split: forwarded to ``torchvision.datasets.STL10`` (downloaded under
            ``HOME/.pytorch/STL10`` when missing).
        channels: 1 averages over dim 1 (keeping that dim with size 1);
            any other value returns the images as stored.
        size: optional side length. When given, images are resized to
            ``(size, size)`` with antialiased bilinear interpolation. The
            default ``None`` keeps the stored resolution, which is what this
            function has always returned: the ``Resize((32, 32))`` transform
            it used to pass to ``STL10`` never took effect, because the arrays
            are read from ``.data`` directly and dataset transforms only run
            inside ``__getitem__``. Pass ``size=32`` to actually get 32x32 images.

    Returns:
        ``TensorDataset(x, y)``.
    '''
    dataset = STL10(f'{HOME}/.pytorch/STL10', split=split, download=True)
    x, y = dataset.data, dataset.labels
    x = x.astype('float32')/255.
    x, y = torch.tensor(x), torch.tensor(y)
    if size is not None:
        # assumes x is (N, C, H, W), consistent with the channel mean over dim 1 below -- TODO confirm
        x = F.interpolate(x, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
    if channels == 1:
        x = x.mean(1, keepdim=True)
    return TensorDataset(x, y)

In [190]:
# Load STL10 train and take one image (batch of 1, last element) for a visual check.
dataset= load_stl10('train')
train_loader = torch.utils.data.DataLoader(dataset,batch_size=1,drop_last=True)
x, y= next(iter(train_loader))
x= x[-1]

Files already downloaded and verified


In [192]:
# Write the selected image to ./image_test.png (note: the other checks save under ./data/).
topil= transforms.ToPILImage()
image= topil(x)
image.save('image_test.png')

# Check CIFAR10

In [196]:
# Load CIFAR-10 train and keep the last image of the first batch of 2 for a visual check.
dataset= load_cifar10('train')
train_loader = torch.utils.data.DataLoader(dataset,batch_size=2,drop_last=True)
x, y= next(iter(train_loader))
x= x[-1]

In [197]:
# Write the selected image to disk; this overwrites the file saved by the PACS check.
topil= transforms.ToPILImage()
image= topil(x)
image.save('./data/image_test.png')

# Wide Resnet

In [1]:
import torchvision


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Build Wide-ResNet-50-2 with pretrained ImageNet weights.
# `weights` expects a member of the weights enum (or None), not the enum class
# itself; passing the class only worked through torchvision's deprecated
# legacy-argument handling. IMAGENET1K_V1 is selected explicitly because it is
# the checkpoint the recorded download below shows (wide_resnet50_2-95faca4d.pth).
torchvision.models.wide_resnet50_2(weights=torchvision.models.resnet.Wide_ResNet50_2_Weights.IMAGENET1K_V1)

Downloading: "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth" to /home/dongkyu/.cache/torch/hub/checkpoints/wide_resnet50_2-95faca4d.pth
100%|██████████| 132M/132M [00:06<00:00, 20.0MB/s] 


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), strid

# Con_LOSS

In [12]:
# Four random 128x128 matrices standing in for per-view feature batches.
# NOTE(review): no seed is set, so the loss values printed below are not reproducible.
a= torch.randn(128,128)
b = torch.randn(128,128)
c = torch.randn(128,128)
d = torch.randn(128,128)
features= [a,b,c,d]

In [15]:
# Used to normalise the cross-correlation matrix in the loss loop and to size logits_mask further down.
batch_size= 128

In [16]:
def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    Dropping the last element of the flattened matrix and regrouping the rest
    into rows of length n + 1 lines every diagonal entry up in column 0, which
    is then sliced away. Entries come out in row-major order.

    Raises AssertionError when ``x`` is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    all_but_last = x.flatten()[:-1]
    diagonal_first = all_but_last.view(rows - 1, rows + 1)
    return diagonal_first[:, 1:].flatten()

In [17]:
# Barlow-Twins-style loss over every ordered pair of views in `features`:
# standardise each view along the batch axis, form the cross-correlation
# matrix, then penalise diagonal entries for differing from 1 and off-diagonal
# entries for differing from 0 (weighted by 0.0051).
# The (p, q) and (q, p) matrices are transposes of each other, so each
# unordered pair contributes the same loss twice.
#
# Changes from the earlier version of this cell:
#   * the correlation matrix no longer rebinds the global name `c`, which is
#     one of the feature tensors defined above;
#   * views are standardised into fresh names instead of re-normalising the
#     loop variable on every inner iteration;
#   * the divisor is read from the tensor itself rather than a global.
total_loss= 0.0
for p, anchor_feature in enumerate(features):
    anchor_norm = (anchor_feature - anchor_feature.mean(0)) / anchor_feature.std(0)
    n_samples = anchor_feature.shape[0]
    for q, contrast_feature in enumerate(features):
        if p != q:
            contrast_norm = (contrast_feature - contrast_feature.mean(0)) / contrast_feature.std(0)
            cross_corr = torch.matmul(anchor_norm.T, contrast_norm) / n_samples
            on_diag = (torch.diagonal(cross_corr) - 1).pow(2).sum()
            off_diag = off_diagonal(cross_corr).pow(2).sum()

            loss = on_diag + 0.0051 * off_diag
            print("{a}--{b} LOSS: {c}".format(a=p,b=q,c= loss))
            total_loss += loss

0--1 LOSS: 130.9720458984375
0--2 LOSS: 130.98446655273438
0--3 LOSS: 129.42559814453125
1--0 LOSS: 130.9720458984375
1--2 LOSS: 129.5026397705078
1--3 LOSS: 128.0531768798828
2--0 LOSS: 130.98446655273438
2--1 LOSS: 129.5026397705078
2--3 LOSS: 130.7725067138672
3--0 LOSS: 129.42559814453125
3--1 LOSS: 128.0531768798828
3--2 LOSS: 130.7725067138672


In [146]:
import itertools

In [154]:
# Scratch arithmetic. NOTE(review): 6 matches the number of unordered view
# pairs computed below; the origin of 53.4333 is not shown in this notebook.
53.4333 * 6

320.5998

In [155]:
# Scratch arithmetic. NOTE(review): 12 matches the number of ordered view
# pairs printed by the loss loop; the origin of 26.2334 is not shown in this notebook.
26.2334 * 12

314.8008

In [152]:
# Enumerate unordered index pairs of the views and check membership of one pair.
# NOTE(review): this rebinds `a`, which earlier named one of the feature tensors.
a= list(itertools.combinations(list(range(len(features))), 2))
if (1,2) in a:
    print("Yes")

Yes


In [158]:
# Number of unordered view pairs.
len(a)

6

In [145]:
# Tensor-native equivalent of itertools.combinations (default pair size 2).
torch.combinations(torch.tensor([1,2,3]))

tensor([[1, 2],
        [1, 3],
        [2, 3]])

In [54]:
# Shape of the second view after splitting `features` along dim 1.
# NOTE(review): hidden state -- the only visible definition of `features` is a
# Python list, but torch.unbind needs a tensor; the recorded output implies a
# tensor with at least two views on dim 1 was in the kernel. Define it explicitly.
torch.unbind(features, dim=1)[1].shape

torch.Size([128, 128])

In [44]:
# Number of views is read from dim 1; the views are then concatenated along
# dim 0 so every view of every sample becomes one row.
# NOTE(review): requires `features` to be a tensor (see the hidden-state note on the unbind cell).
contrast_count = features.shape[1]
contrast_feature= torch.cat(torch.unbind(features, dim=1), dim=0)

In [55]:
# Display the stacked views. NOTE(review): the recorded values all lie in [0, 1],
# so these features were presumably not produced by torch.randn -- confirm their source.
contrast_feature

tensor([[0.7332, 0.0929, 0.2580,  ..., 0.3973, 0.9955, 0.8535],
        [0.0841, 0.8132, 0.7651,  ..., 0.0038, 0.9857, 0.8988],
        [0.0678, 0.3891, 0.4295,  ..., 0.6152, 0.9006, 0.6406],
        ...,
        [0.4338, 0.5726, 0.2865,  ..., 0.3523, 0.1264, 0.3898],
        [0.1844, 0.9726, 0.0948,  ..., 0.2757, 0.4321, 0.1781],
        [0.0125, 0.1082, 0.3442,  ..., 0.6851, 0.8039, 0.7882]])

In [46]:
# Variant that uses only the first view as anchor; in document order the next cell overrides it.
anchor_feature = features[:, 0]

In [11]:
# Variant that uses every view as an anchor.
anchor_feature = contrast_feature
anchor_count = contrast_count

In [13]:
# Pairwise dot products between anchors and contrast rows, divided by 0.07 (temperature).
# NOTE(review): the recorded values are in the hundreds, which suggests the
# rows were not L2-normalised before the dot product -- confirm.
anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            0.07)

In [14]:
# Display the scaled similarity matrix.
anchor_dot_contrast

tensor([[596.3107, 452.5212, 462.6601,  ..., 410.0806, 430.4772, 447.1158],
        [452.5212, 615.8022, 454.5374,  ..., 439.8339, 446.4165, 479.2466],
        [462.6601, 454.5374, 627.1229,  ..., 443.9037, 501.7697, 459.4048],
        ...,
        [410.0806, 439.8339, 443.9037,  ..., 558.7839, 472.4926, 447.3414],
        [430.4772, 446.4165, 501.7697,  ..., 472.4926, 632.8708, 477.7922],
        [447.1158, 479.2466, 459.4048,  ..., 447.3414, 477.7922, 637.1094]])

In [15]:
# Row-wise maximum, subtracted in the next cell to keep exp() from overflowing.
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)

In [17]:
# Shift each row so its maximum is 0; detach() keeps the shift out of the autograd graph.
logits= anchor_dot_contrast- logits_max.detach()

In [19]:
# Shape check of the logits matrix.
logits.shape

torch.Size([256, 256])

In [26]:
# Tile the mask so it covers every (anchor view, contrast view) block.
# NOTE(review): hidden state -- `mask` is never defined in the visible notebook,
# and this cell is not idempotent (each re-run tiles the mask again).
mask= mask.repeat(anchor_count, contrast_count)

In [30]:
# All-ones matrix with zeros scattered onto the diagonal, used to exclude each row's self-comparison.
logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(batch_size*anchor_count).view(-1,1),0)

In [31]:
# Display the self-comparison mask.
logits_mask

tensor([[0., 1., 1.,  ..., 1., 1., 1.],
        [1., 0., 1.,  ..., 1., 1., 1.],
        [1., 1., 0.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 0., 1., 1.],
        [1., 1., 1.,  ..., 1., 0., 1.],
        [1., 1., 1.,  ..., 1., 1., 0.]])

In [32]:
# Zero the diagonal of the mask as well.
mask= mask*logits_mask

In [34]:
# Exponentiated logits with the self-comparison entries zeroed out.
exp_logits = torch.exp(logits) * logits_mask

In [36]:
# log(1 - p), where p is each entry's share of its row's masked exp-sum;
# the two 1e-6 terms guard the division and the log.
log_prob = torch.log( 1- exp_logits / (exp_logits.sum(1, keepdim=True)+1e-6) - 1e-6)

In [37]:
# Display the result.
# NOTE(review): every recorded entry is about -1e-6, i.e. log(1 - 1e-6), which
# means every masked exp_logits entry was effectively 0. Presumably exp()
# underflowed because the un-normalised logits are hundreds below the row
# maximum (the diagonal, which is masked out) -- confirm and normalise the features.
log_prob

tensor([[-1.0133e-06, -1.0133e-06, -1.0133e-06,  ..., -1.0133e-06,
         -1.0133e-06, -1.0133e-06],
        [-1.0133e-06, -1.0133e-06, -1.0133e-06,  ..., -1.0133e-06,
         -1.0133e-06, -1.0133e-06],
        [-1.0133e-06, -1.0133e-06, -1.0133e-06,  ..., -1.0133e-06,
         -1.0133e-06, -1.0133e-06],
        ...,
        [-1.0133e-06, -1.0133e-06, -1.0133e-06,  ..., -1.0133e-06,
         -1.0133e-06, -1.0133e-06],
        [-1.0133e-06, -1.0133e-06, -1.0133e-06,  ..., -1.0133e-06,
         -1.0133e-06, -1.0133e-06],
        [-1.0133e-06, -1.0133e-06, -1.0133e-06,  ..., -1.0133e-06,
         -1.0133e-06, -1.0133e-06]])