## baseline model

In [1]:
import random
import os
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
import numpy as np


In [2]:
import torch.utils.data as data
from PIL import Image
import os
import torch.nn as nn


class GetLoader(data.Dataset):
    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        f = open(data_list, 'r')
        data_list = f.readlines()
        f.close()

        self.n_data = len(data_list)

        self.img_paths = []
        self.img_labels = []

        for data in data_list:
            self.img_paths.append(data[:-3])
            self.img_labels.append(data[-2])

    def __getitem__(self, item):
        img_paths, labels = self.img_paths[item], self.img_labels[item]
        imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')

        if self.transform is not None:
            imgs = self.transform(imgs)
            labels = int(labels)

        return imgs, labels

    def __len__(self):
        return self.n_data

In [3]:
import os
import warnings

import torch
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive


class SyntheticDigits(VisionDataset):
    """Synthetic Digits Dataset.
    """

    resources = [
        ('https://github.com/liyxi/synthetic-digits/releases/download/data/synth_train.pt.gz',
         'd0e99daf379597e57448a89fc37ae5cf'),
        ('https://github.com/liyxi/synthetic-digits/releases/download/data/synth_test.pt.gz',
         '669d94c04d1c91552103e9aded0ee625')
    ]

    training_file = "synth_train.pt"
    test_file = "synth_test.pt"
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Init Synthetic Digits dataset."""
        super(SyntheticDigits, self).__init__(root, transform=transform, target_transform=target_transform)

        self.train = train

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file

        print(os.path.join(self.processed_folder, data_file))

        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """Get images and target for data loader.
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.squeeze().numpy(), mode="RGB")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return size of dataset."""
        return len(self.data)

    @property
    def raw_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder, self.test_file)))

    def download(self):
        """Download the Synthetic Digits data."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder,
                                         extract_root=self.processed_folder,
                                         filename=filename, md5=md5)

        print('Done!')

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")

In [4]:
transform_syn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(32),
    transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
])

# Transform for SVHN
transform_svhn = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Resize(28),
    transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
])

# Transform for MNIST
transform_mnist = transforms.Compose([
    transforms.Resize(32),  # Resize MNIST images to 32x32 to match SVHN format
    transforms.Grayscale(3),  # Convert MNIST images to 3 channels
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Transform for MNIST-M
transform_mnistm = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize MNIST images to 32x32 to match SVHN format
    transforms.Grayscale(3),  # Convert MNIST images to 3 channels
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
batch_size = 128

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# # Load target SVHN dataset
source_svhn_dataset = datasets.SVHN(root='./data', split='train', download=True, transform=transform_svhn)
dataloader_svhn_source = DataLoader(source_svhn_dataset, batch_size=batch_size, shuffle=True)

# Load source Syn Digits dataset
# source_syn_dataset = SyntheticDigits(root='./syn_dataset', train=True, download=True, transform=transform_syn)
# dataloader_syn_source = DataLoader(source_syn_dataset, batch_size=batch_size, shuffle=True)

# Load source MNIST Digits dataset
# source_mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist)
# dataloader_mnist_source = DataLoader(source_mnist_dataset, batch_size=batch_size, shuffle=True)

Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./data/train_32x32.mat


100%|██████████| 182040794/182040794 [00:31<00:00, 5712626.39it/s]


In [None]:
mn = DataLoader(datasets.MNIST(root='./data', train=True, download=True, transform=transform_mnist), batch_size=128, shuffle=True)
sy = DataLoader(SyntheticDigits(root='./data', train=True, download=True, transform=transform_syn), batch_size=128, shuffle=True)

for i in mn:
  p, q = i
  print(p.shape)
  break
for j in sy:
  p, q = i
  print(p.shape)
  break

In [None]:
# import tarfile

# with tarfile.open('mnist_m.tar', "r:gz") as tar:
#     # Extract all the contents into the current directory
#     tar.extractall()

!tar -xvf  'mnist_m.tar'

In [7]:
for data in dataloader_svhn_source:
    imgs, labels = data
    print(imgs.shape)
    print(labels.shape)
    break

torch.Size([128, 3, 32, 32])
torch.Size([128])


In [16]:
import torch.nn as nn
from torch.autograd import Function

class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None

class CNNModel(nn.Module):
    def __init__(self, input_channels=3, num_classes=10, domain_classes=2):
        super(CNNModel, self).__init__()

        # Feature extractor
        self.feature = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=5),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
            nn.ReLU(True),
            nn.Conv2d(64, 50, kernel_size=5),
            nn.BatchNorm2d(50),
            nn.Dropout2d(),
            nn.MaxPool2d(2),
            nn.ReLU(True)
        )

        # Class classifier
        self.class_classifier = nn.Sequential(
            nn.Linear(50 * 5 * 5, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(True),
            nn.Dropout2d(),
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(True),
            nn.Linear(100, num_classes),
            nn.LogSoftmax(dim=1)
        )


    def forward(self, input_data, alpha):
      feature = self.feature(input_data)

    # Calculate the size of the flattened features
      size = feature.size()[1:]  # all dimensions except the batch dimension
      num_features = 1
      for s in size:
          num_features *= s

      feature = feature.view(-1, num_features)  # Automatically infer batch size
      class_output = self.class_classifier(feature)

      return class_output

In [None]:
transform_syn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(32),
    transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
])

# Transform for SVHN
transform_svhn = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Resize(28),
    transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
])

# Transform for MNIST
transform_mnist = transforms.Compose([
    # transforms.Resize((32, 32)),  # Resize MNIST images to 32x32 to match SVHN format
    transforms.Grayscale(3),  # Convert MNIST images to 3 channels
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Transform for MNIST-M
transform_mnistm = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize MNIST images to 32x32 to match SVHN format
    transforms.Grayscale(3),  # Convert MNIST images to 3 channels
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
batch_size = 128

In [9]:
import os
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ... (other necessary imports)

def test(dataset_name, epoch):
    assert dataset_name in ['MNIST', 'mnist_m', 'svhn', 'syn']

    model_root = 'mnist_model_epoch.pth'
    image_root = 'mnist_m'

    cuda = True
    cudnn.benchmark = True
    batch_size = 128
    image_size = 32
    alpha = 0

    """ Load Data """
    transform_mnistm = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    transform_mnist = transforms.Compose([
        transforms.Resize(image_size),  # Resize MNIST images to 32x32 to match SVHN format
        transforms.Grayscale(3),  # Convert MNIST images to 3 channels
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])


    if dataset_name == 'mnist_m':
        test_list = os.path.join(image_root, 'mnist_m_test_labels.txt')
        dataset = GetLoader(
            data_root=os.path.join(image_root, 'mnist_m_test'),
            data_list=test_list,
            transform=transform_mnistm
        )
    elif dataset_name == 'svhn':
        transform_svhn = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
        ])
        dataset = datasets.SVHN(root='./data', split='test', download=True, transform=transform_svhn)
    elif dataset_name == 'syn':
        transform_syn = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(32),
            transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
        ])
        dataset = SyntheticDigits(root='./syn_dataset', train=False, download=True, transform=transform_syn)

    else:  # MNIST
        dataset = datasets.MNIST(root='./data', train=False, transform=transform_mnist, download = True)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    """ Evaluate Model """
    my_net = torch.load(model_root)
    my_net = my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    n_total = 0
    n_correct = 0

    with torch.no_grad():
        for data in dataloader:
            t_img, t_label = data
            if cuda:
                t_img = t_img.cuda()
                t_label = t_label.cuda()

            class_output = my_net(input_data=t_img, alpha=alpha)

            pred = class_output.data.max(1, keepdim=True)[1]
            n_correct += pred.eq(t_label.data.view_as(pred)).cpu().sum()
            n_total += t_label.size(0)

    accu = n_correct.item() * 1.0 / n_total
    print(f'epoch: {epoch}, accuracy of the {dataset_name} dataset: {accu}')


In [11]:
test('syn',1)

./syn_dataset/SyntheticDigits/processed/synth_test.pt




epoch: 1, accuracy of the syn dataset: 0.3073380090024076


In [20]:
cuda = True
cudnn.benchmark = True
lr = 1e-3
batch_size = 128
image_size = 32
n_epoch = 50

manual_seed = 2023
random.seed(manual_seed)
torch.manual_seed(manual_seed)


my_net = CNNModel()
optimizer = optim.Adam(my_net.parameters(), lr=lr)

loss_class = torch.nn.NLLLoss()

if cuda:
    my_net = my_net.cuda()
    loss_class = loss_class.cuda()

for p in my_net.parameters():
    p.requires_grad = True



In [21]:
syn_source_acc=[]
mnist_target_acc=[]
mnistm_target_acc = []
svhn_target_acc = []

for epoch in range(n_epoch):

    len_dataloader = len(dataloader_svhn_source)
    data_source_iter = iter(dataloader_svhn_source)
    # data_target_iter = iter(dataloader_target)

    i = 0
    while i < len_dataloader:

        p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        # training model using source data
        data_source = next(data_source_iter)
        s_img, s_label = data_source

        my_net.zero_grad()
        batch_size = len(s_label)
        #print(batch_size)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)
        domain_label = torch.zeros(batch_size)
        domain_label = domain_label.long()

        if cuda:
            s_img = s_img.cuda()
            s_label = s_label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        input_img.resize_as_(s_img).copy_(s_img)
        class_label.resize_as_(s_label).copy_(s_label)
        #print(input_img.shape)
        #print(class_label.shape)

        class_output = my_net(input_data=input_img, alpha=alpha)
        #print(class_output.shape)
        #print(domain_output.shape)
        err_s_label = loss_class(class_output, class_label)

        # training model using target data

        err = err_s_label
        err.backward()
        optimizer.step()

        i += 1

    print('epoch: %d, [iter: %d / all %d], err_s_label: %f' % (epoch, i, len_dataloader, err_s_label.cpu().data.numpy()))

    torch.save(my_net, '../mnist_model_epoch.pth')
    # mnist=test('MNIST', epoch)
    # mnist_target_acc.append(mnist)

    # mnistm=test('mnist_m', epoch)
    # mnistm_target_acc.append(mnistm)

    # svhn = test('svhn',epoch)
    # svhn_target_acc.append(svhn)

    syn = test('syn', epoch)
    syn_source_acc.append(syn)



epoch: 0, [iter: 573 / all 573], err_s_label: 0.936699
./syn_dataset/SyntheticDigits/processed/synth_test.pt
epoch: 0, accuracy of the syn dataset: 0.3073380090024076
epoch: 1, [iter: 573 / all 573], err_s_label: 0.678260
./syn_dataset/SyntheticDigits/processed/synth_test.pt
epoch: 1, accuracy of the syn dataset: 0.3073380090024076
epoch: 2, [iter: 573 / all 573], err_s_label: 0.878667
./syn_dataset/SyntheticDigits/processed/synth_test.pt
epoch: 2, accuracy of the syn dataset: 0.3073380090024076
epoch: 3, [iter: 573 / all 573], err_s_label: 0.665408
./syn_dataset/SyntheticDigits/processed/synth_test.pt
epoch: 3, accuracy of the syn dataset: 0.3073380090024076
epoch: 4, [iter: 573 / all 573], err_s_label: 0.624948
./syn_dataset/SyntheticDigits/processed/synth_test.pt
epoch: 4, accuracy of the syn dataset: 0.3073380090024076
epoch: 5, [iter: 573 / all 573], err_s_label: 0.404019
./syn_dataset/SyntheticDigits/processed/synth_test.pt
epoch: 5, accuracy of the syn dataset: 0.307338009002407

## SVHN trained prediction on mnist,mnist-m

In [None]:
test('MNIST',1)
test('mnist_m',2)
test('svhn',3)

## MNIST trained prediction on mnist-m,svhn

In [None]:
test('MNIST',1)
test('mnist_m',2)
test('svhn',3)