In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data import Dataset
from torch.utils.data import random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle
import os
from PIL import Image
import itertools
from mnistm import MNISTMDataset
from syn_digits import SyntheticDigits

In [None]:
# Set device: prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Display which accelerator is in use. torch.cuda.get_device_name() raises on a
# machine without CUDA, so it is only queried when a GPU was actually selected;
# on a CPU-only host the cell shows "cpu" instead of failing.
torch.cuda.get_device_name() if device.type == "cuda" else "cpu"

'NVIDIA GeForce RTX 3080'

# Data Preprocessing

In [None]:
# Preprocessing pipelines, one per dataset. Every pipeline ends with the same
# normalisation, (x - 0.5) / 0.5, applied to the tensor produced by ToTensor.

# MNIST: resize to 32x32 (the size used for SVHN) and expand the single grey
# channel to three channels so the images fit the 3-channel network input.
transform_mnist = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    # A single mean/std value is given; it is applied to all three channels.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST-M: resized to 32x32 and normalised with explicit per-channel RGB statistics.
transform_mnist_m = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# SVHN: no resize step, only tensor conversion and normalisation.
transform_svhn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Synthetic digits: same steps as SVHN.
transform_syn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])


In [None]:
# Load the train/test splits of the four digit datasets.
# Trailing numbers are the dataset sizes noted by the original author.

# MNIST (torchvision, downloaded into ./data when missing)
mnist_train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform_mnist)  # 60000
mnist_test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform_mnist)  # 10000

# MNIST-M (project-local dataset class, read from ./dataset)
mnistm_train_dataset = MNISTMDataset(
    "./dataset", train=True, transform=transform_mnist_m)  # 59001
mnistm_test_dataset = MNISTMDataset(
    "./dataset", train=False, transform=transform_mnist_m)  # 9001

# SVHN (torchvision, downloaded into ./data when missing)
svhn_train_dataset = datasets.SVHN(
    root='./data', split='train', download=True, transform=transform_svhn)  # 73257
svhn_test_dataset = datasets.SVHN(
    root='./data', split='test', download=True, transform=transform_svhn)  # 26032

# Synthetic digits (project-local dataset class, stored under ./syn_dataset)
syn_train_dataset = SyntheticDigits(
    root='./syn_dataset', train=True, transform=transform_syn,
    target_transform=None, download=True)
syn_test_dataset = SyntheticDigits(
    root='./syn_dataset', train=False, transform=transform_syn,
    target_transform=None, download=True)

Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat
./syn_dataset\SyntheticDigits\processed\synth_train.pt
./syn_dataset\SyntheticDigits\processed\synth_test.pt


In [None]:
batch_size = 64  # Mini-batch size shared by every loader
VAL_FRACTION = 0.15  # Share of each training set held out for validation
SPLIT_SEED = 0  # Fixed seed so the train/validation split is identical on every run


def make_split_loaders(train_set, test_set, batch_size=batch_size,
                       val_fraction=VAL_FRACTION, seed=SPLIT_SEED):
    """Build train/validation/test DataLoaders for one dataset.

    Args:
        train_set: Dataset that is split into a training and a validation part.
        test_set: Dataset used as-is for the test loader.
        batch_size: Batch size of all three loaders.
        val_fraction: Fraction of ``train_set`` held out for validation
            (``int(val_fraction * len(train_set))`` samples).
        seed: Seed of the generator that drives ``random_split``, making the
            split reproducible across kernel restarts.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles; evaluation does not depend on sample order.
    """
    total_size = len(train_set)
    val_size = int(val_fraction * total_size)
    train_size = total_size - val_size
    split_generator = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(
        train_set, [train_size, val_size], generator=split_generator)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader


# One call per dataset replaces four copies of the same split/loader recipe.
mnist_train_loader, mnist_val_loader, mnist_test_loader = make_split_loaders(
    mnist_train_dataset, mnist_test_dataset)
mnistm_train_loader, mnistm_val_loader, mnistm_test_loader = make_split_loaders(
    mnistm_train_dataset, mnistm_test_dataset)
svhn_train_loader, svhn_val_loader, svhn_test_loader = make_split_loaders(
    svhn_train_dataset, svhn_test_dataset)
syn_train_loader, syn_val_loader, syn_test_loader = make_split_loaders(
    syn_train_dataset, syn_test_dataset)

# Model Architecture

In [None]:

# Feature extractor (e.g., a simple CNN for image tasks)
class FeatureExtractor(nn.Module):
    """Small CNN that turns 3-channel 32x32 images into feature vectors.

    The network has two conv-conv-pool stages followed by two fully connected
    layers. ``forward`` returns the activations at three depths so that a
    loss can be attached to each of them.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 -> 32 channels, unpadded 3x3 convolutions.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        # Pooling and dropout modules are shared by both stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.drop = nn.Dropout2d()
        # Stage 2: 32 -> 64 -> 64 channels.
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        # Spatial size for a 32x32 input: 32 -> 30 -> 28 -> pool 14 -> 12 -> 10 -> pool 5,
        # so the flattened map holds 64 * 5 * 5 values.
        self.fc1 = nn.Linear(in_features=64 * 5 * 5, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)

    def forward(self, x):
        """Return ``(feature1, feature2, features)``.

        ``feature1`` is the flattened convolutional output (64*5*5 values per
        row), ``feature2`` the 128-d output of ``fc1`` after ReLU, and
        ``features`` the 64-d output of ``fc2`` after ReLU, which is what the
        classifier consumes.
        """
        stage1 = self.bn1(F.relu(self.conv1(x)))
        stage1 = self.drop(self.pool(F.relu(self.conv2(stage1))))

        stage2 = self.bn2(F.relu(self.conv3(stage1)))
        stage2 = self.drop(self.pool(F.relu(self.conv4(stage2))))

        # Flatten to rows of fc1's input width (64 * 5 * 5).
        feature1 = stage2.view(-1, self.fc1.in_features)
        feature2 = F.relu(self.fc1(feature1))
        features = F.relu(self.fc2(feature2))
        return feature1, feature2, features

class Classifier(nn.Module):
    """Linear classification head that outputs per-class log-probabilities."""

    def __init__(self, num_classes=10):
        super().__init__()
        # Input width 64 matches the last tensor returned by FeatureExtractor.forward;
        # output width is the number of classes.
        self.fc3 = nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x):
        """Map a batch of 64-d features to log-probabilities over the classes."""
        logits = self.fc3(x)
        return logits.log_softmax(dim=1)


In [None]:
class MMDLoss(nn.Module):
    """Maximum Mean Discrepancy between two feature batches.

    The kernel is a sum of ``kernel_num`` Gaussian (RBF) kernels whose
    bandwidths form a geometric series with ratio ``kernel_mul``, centred on
    either ``fix_sigma`` or the mean pairwise squared distance of the batch.

    Args:
        kernel_type: Only ``'rbf'`` is implemented.
        kernel_mul: Ratio between consecutive bandwidths.
        kernel_num: Number of Gaussian kernels that are summed.
        fix_sigma: Optional fixed base bandwidth. When ``None`` (default) the
            bandwidth is estimated from the data on every call.
    """

    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super(MMDLoss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = fix_sigma
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):
        """Return the summed multi-bandwidth RBF kernel matrix.

        ``source`` and ``target`` are 2-D tensors (samples x features) with
        the same feature width. The result is a square matrix over the
        concatenation ``[source; target]``.
        (The method name keeps its original spelling so existing callers work.)
        """
        n_samples = int(source.size(0)) + int(target.size(0))
        total = torch.cat([source, target], dim=0)

        # Pairwise squared Euclidean distances through broadcasting.
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
        # Keep distances strictly positive (the diagonal would otherwise be 0).
        L2_distance = torch.clamp(L2_distance, min=1e-8)

        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Mean squared distance over the off-diagonal pairs; detached so the
            # bandwidth acts as a constant during back-propagation.
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)

        # Out-of-place division: an in-place `/=` would modify a tensor
        # `fix_sigma` owned by the caller and shrink it on every call.
        bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def forward(self, source, target):
        """Compute the MMD estimate between ``source`` and ``target``.

        Returns:
            A scalar tensor ``mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts)``.

        Raises:
            ValueError: If ``kernel_type`` is not ``'rbf'``. Previously this
                case silently returned ``None``.
        """
        if self.kernel_type != 'rbf':
            raise ValueError(
                "Unsupported kernel_type {!r}; only 'rbf' is implemented".format(self.kernel_type))

        n_source = int(source.size(0))
        kernels = self.guassian_kernel(
            source, target,
            kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
        # The first n_source rows/columns belong to the source batch, the rest to the target.
        XX = torch.mean(kernels[:n_source, :n_source])
        YY = torch.mean(kernels[n_source:, n_source:])
        XY = torch.mean(kernels[:n_source, n_source:])
        YX = torch.mean(kernels[n_source:, :n_source])
        return XX + YY - XY - YX

In [None]:
# Instantiate the networks and the MMD loss on the selected device.
feature_extractor = FeatureExtractor().to(device)
classifier = Classifier().to(device)
mmd_loss = MMDLoss().to(device)

# Classification loss. Classifier.forward already returns log-probabilities
# (it ends in log_softmax), so the matching criterion is NLLLoss.
# CrossEntropyLoss expects raw logits and would apply log-softmax a second time.
criterion = nn.NLLLoss()

# A single Adam optimizer updates the feature extractor and the classifier jointly.
optimizer = torch.optim.Adam(
    list(feature_extractor.parameters()) + list(classifier.parameters()), lr=0.001)


In [None]:
def train(epoch, source_loader, target_loader, mmd_weight=0.1):
    """Run one epoch of MMD-regularised training.

    Source batches contribute a supervised classification loss; source and
    target batches together contribute an MMD penalty computed at the three
    feature depths returned by the feature extractor. Uses the notebook-level
    ``feature_extractor``, ``classifier``, ``mmd_loss``, ``criterion``,
    ``optimizer`` and ``device``.

    Args:
        epoch: Epoch number, used only in the printed summary.
        source_loader: Loader yielding ``(inputs, labels)`` from the source domain.
        target_loader: Loader yielding ``(inputs, labels)`` from the target
            domain; the labels are ignored.
        mmd_weight: Weight of the MMD term in the total loss (default 0.1).

    Returns:
        Tuple ``(average_loss, accuracy_percent)`` over the source batches seen.

    Raises:
        ValueError: If the loaders yield no batch pair.
    """
    feature_extractor.train()
    classifier.train()

    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    # zip stops at the shorter loader, so an epoch covers min(len(source), len(target)) batches.
    for (source_inputs, source_labels), (target_inputs, _) in zip(source_loader, target_loader):
        source_inputs, source_labels = source_inputs.to(device), source_labels.to(device)
        target_inputs = target_inputs.to(device)

        optimizer.zero_grad()

        # Forward pass: features for both domains, class scores for the source only.
        source_f1, source_f2, source_features = feature_extractor(source_inputs)
        class_outputs = classifier(source_features)
        tar_f1, tar_f2, target_features = feature_extractor(target_inputs)

        # Supervised loss on the labelled source batch.
        cls_loss = criterion(class_outputs, source_labels)

        # MMD between source and target activations at each depth, evaluated in float64.
        mmd = (
            mmd_loss(source_f1.double(), tar_f1.double())
            + mmd_loss(source_f2.double(), tar_f2.double())
            + mmd_loss(source_features.double(), target_features.double())
        )

        loss = cls_loss + mmd_weight * mmd
        loss.backward()
        optimizer.step()

        # Accumulate a Python float: adding the loss tensor itself would keep
        # every iteration's result attached to the autograd graph.
        total_loss += loss.item()
        num_batches += 1

        predicted = class_outputs.detach().argmax(dim=1)
        total += source_labels.size(0)
        correct += (predicted == source_labels).sum().item()

    if num_batches == 0:
        raise ValueError("train() received no batches: source_loader or target_loader is empty")

    average_loss = total_loss / num_batches
    accuracy = 100. * correct / total
    print('Train Epoch: {} \tLoss: {:.6f} \tAccuracy: {:.2f}%'.format(
        epoch, average_loss, accuracy))
    return average_loss, accuracy


In [None]:
def validate(loader):
    """Evaluate the current model on ``loader`` and print loss and accuracy.

    Uses the notebook-level ``feature_extractor``, ``classifier``,
    ``criterion`` and ``device``. Both networks are switched to eval mode.

    Args:
        loader: Loader yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(average_loss, accuracy_percent)``; the loss is the mean of
        the per-batch losses.

    Raises:
        ValueError: If ``loader`` yields no samples (previously an opaque
            ZeroDivisionError).
    """
    feature_extractor.eval()
    classifier.eval()

    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Only the final feature vector is needed for classification.
            _, _, features = feature_extractor(inputs)
            outputs = classifier(features)

            test_loss += criterion(outputs, labels).item()
            num_batches += 1

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if num_batches == 0 or total == 0:
        raise ValueError("validate() received an empty loader")

    test_loss /= num_batches
    accuracy = 100. * correct / total
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(test_loss, accuracy))
    return test_loss, accuracy


In [None]:
# Adaptation run: synthetic digits are the labelled source domain, SVHN is the
# target domain (its labels are not used by train()). Ten epochs, numbered from 1.
for epoch in range(1, 10 + 1):
    train(epoch, syn_train_loader, svhn_train_loader)
    # Monitor the held-out SVHN validation split after every epoch.
    validate(svhn_val_loader)

# Final evaluation on the SVHN test split once training has finished.
validate(svhn_test_loader)

Train Epoch: 1 	Loss: 0.579821 	Accuracy: 82.13%

Validation set: Average loss: 0.8252, Accuracy: 74.22%

Train Epoch: 2 	Loss: 0.263379 	Accuracy: 92.83%

Validation set: Average loss: 0.7837, Accuracy: 76.83%

Train Epoch: 3 	Loss: 0.213972 	Accuracy: 94.29%

Validation set: Average loss: 0.7101, Accuracy: 79.09%

Train Epoch: 4 	Loss: 0.193516 	Accuracy: 94.97%

Validation set: Average loss: 0.7692, Accuracy: 78.48%

Train Epoch: 5 	Loss: 0.178023 	Accuracy: 95.42%

Validation set: Average loss: 0.7016, Accuracy: 79.03%

Train Epoch: 6 	Loss: 0.167565 	Accuracy: 95.72%

Validation set: Average loss: 0.7037, Accuracy: 80.00%

Train Epoch: 7 	Loss: 0.153118 	Accuracy: 96.18%

Validation set: Average loss: 0.7150, Accuracy: 79.90%

Train Epoch: 8 	Loss: 0.146427 	Accuracy: 96.30%

Validation set: Average loss: 0.7393, Accuracy: 79.44%

Train Epoch: 9 	Loss: 0.141679 	Accuracy: 96.66%

Validation set: Average loss: 0.6898, Accuracy: 79.91%

Train Epoch: 10 	Loss: 0.138052 	Accuracy: 96.

In [None]:
# Evaluate the trained model on the MNIST test split. MNIST is not passed to
# the training call shown in this notebook, so this measures transfer to that dataset.
validate(mnist_test_loader)


Validation set: Average loss: 0.6621, Accuracy: 85.80%



In [None]:
# Evaluate the trained model on the MNIST-M test split. MNIST-M is not passed to
# the training call shown in this notebook, so this measures transfer to that dataset.
validate(mnistm_test_loader)


Validation set: Average loss: 1.7577, Accuracy: 55.04%

