In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

class WDGRL(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        
        # Feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, hidden_dim)
        )

        # Domain critic
        self.domain_critic = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
        )
    
    def forward(self, x):
        
        extracted = self.feature_extractor(x)
        class_pred = self.classifier(extracted)
        
        return class_pred

In [21]:
from gen_data import *
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import grad
from tqdm.notebook import trange
import matplotlib.pyplot as plt

def loop_iterable(iterable):
    while True:
        yield from iterable

def gradient_penalty(critic, h_s, h_t):
    # based on: https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py#L116
    alpha = torch.rand(h_s.size(0), 1)
    differences = h_t - h_s
    interpolates = h_s + (alpha * differences)
    interpolates = torch.stack([interpolates, h_s, h_t]).requires_grad_()

    preds = critic(interpolates)
    gradients = grad(preds, interpolates,
                     grad_outputs=torch.ones_like(preds),
                     retain_graph=True, create_graph=True)[0]
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = ((gradient_norm - 1)**2).mean()
    return gradient_penalty

def train_WDGRL(input_dim, hidden_dim, output_dim=2,num_samples=1000, lr=0.001, 
             batch_size=128, max_epochs=1000000):
    
    batch_size = min(batch_size, num_samples)
    mu = 0
    delta = 2
    D_train_num = 10
    gamma = 10
    torch.autograd.set_detect_anomaly(True)

    X_s, y_s = gen_data(mu, delta, num_samples, input_dim)
    X_t, _ = gen_data(mu + 5, delta, num_samples, input_dim)

    with open(f'training_set/source_{input_dim}d.txt', 'w') as f:
        for i in range(len(X_s)):
            f.write(' '.join(map(str, X_s[i])) + ' ' + str(y_s[i]) + '\n')

    
    with open(f'training_set/target_{input_dim}d.txt', 'w') as f:
        for i in range(len(X_t)):
            f.write(' '.join(map(str, X_t[i])) + '\n')
    # training_set = []
    # with open(f'training_set/{input_dim}d.txt', 'r') as f:
    #     for line in f:
    #         training_set.append(list(map(float, line.strip().split())))

    X_s = torch.tensor(X_s, dtype=torch.float32)
    y_s = torch.tensor(y_s, dtype=torch.float32)
    source_dataset = TensorDataset(X_s, y_s)
    source_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True)
    X_t = torch.tensor(X_t, dtype=torch.float32)
    target_dataset = TensorDataset(X_t)
    target_loader = DataLoader(target_dataset, batch_size=batch_size, shuffle=True)
    
    # model = WDGRL(input_dim, hidden_dim, output_dim)
    # critic_optim = optim.Adam(model.domain_critic.parameters(), lr=lr)
    # feature_optim = optim.Adam(model.feature_extractor.parameters(), lr=lr)
    # clf_optim = optim.Adam(model.classifier.parameters(), lr=lr)
    # clf_criterion = nn.CrossEntropyLoss()
    # # optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # epoch_losses = []
    # epoch_loss = torch.inf
    # for i in range(max_epochs):
    #     batch_iterator = zip(loop_iterable(source_loader), loop_iterable(target_loader))

    #     total_loss = 0
    #     for _ in range(D_train_num):
    #         (xs, ys), (xt) = next(batch_iterator)
    #         xt = xt[0]
           
    #         h_s = model.feature_extractor(xs)
    #         h_t = model.feature_extractor(xt)

    #         for _ in range(D_train_num):
    #             gp = gradient_penalty(model.domain_critic, h_s, h_t)

    #             critic_s = model.domain_critic(h_s)
    #             critic_t = model.domain_critic(h_t)
    #             watterstein_distance = critic_s.mean() - critic_t.mean()
    #             critic_cost = -watterstein_distance + gamma * gp

    #             critic_optim.zero_grad()
    #             critic_cost.backward(retain_graph=True)
    #             critic_optim.step()
            
    #         loss = clf_criterion(h_s, ys.long())
    #         clf_optim.zero_grad()
    #         loss.backward(retain_graph=True)
    #         clf_optim.step()

    #         critic_s = model.domain_critic(h_s)
    #         critic_t = model.domain_critic(h_t)
    #         watterstein_distance = critic_s.mean() - critic_t.mean()
    #         critic_loss = -watterstein_distance + gamma * gp

    #         feature_extractor_loss = critic_loss + loss
    #         feature_optim.zero_grad()
    #         feature_extractor_loss.backward()
    #         feature_optim.step()
    #         print(f"Epoch {len(epoch_losses)}\t\tLoss: {loss:.4f}")

            

    
    # print("Final loss: {:.4f}".format(epoch_loss))
    # if epoch_loss > 0.05:
    #     print("Warning: Loss did not converge to desired value")

    # plt.plot(range(1, len(epoch_losses)+1), epoch_losses)
    # plt.xlabel('Epoch')
    # plt.ylabel('Loss')
    # plt.title(f'Training Loss ({input_dim}D -> {hidden_dim}D)')
    # plt.savefig(f"loss/{input_dim}d_{hidden_dim}d.png")
    # plt.close()
    # # Save the trained model
    # torch.save(model.state_dict(), f"model/{input_dim}d_{hidden_dim}d.pth")
    # print(f"Model saved as 'model/{input_dim}d_{hidden_dim}d.pth'")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class FeatureExtractor(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(FeatureExtractor, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
    def forward(self, x):
        return self.network(x)

class Classifier(nn.Module):
    def __init__(self, hidden_dim, num_classes):
        super(Classifier, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )
        
    def forward(self, x):
        return self.network(x)

class DomainDiscriminator(nn.Module):
    def __init__(self, hidden_dim):
        super(DomainDiscriminator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def forward(self, x):
        return self.network(x)

class WDGRL:
    def __init__(self, input_dim, hidden_dim, num_classes, lambda_wd=1.0, n_critic=5):
        self.feature_extractor = FeatureExtractor(input_dim, hidden_dim)
        self.classifier = Classifier(hidden_dim, num_classes)
        self.domain_discriminator = DomainDiscriminator(hidden_dim)
        
        self.lambda_wd = lambda_wd
        self.n_critic = n_critic
        
        # Optimizers
        self.opt_feat_class = optim.Adam(list(self.feature_extractor.parameters()) + 
                                       list(self.classifier.parameters()))
        self.opt_disc = optim.Adam(self.domain_discriminator.parameters())
        
        self.criterion = nn.CrossEntropyLoss()
    
    def compute_wasserstein_distance(self, source_features, target_features):
        # Compute Wasserstein distance using domain discriminator
        source_pred = self.domain_discriminator(source_features)
        target_pred = self.domain_discriminator(target_features)
        return torch.mean(source_pred) - torch.mean(target_pred)
    
    def train_step(self, source_data, source_labels, target_data):
        # Extract features
        source_features = self.feature_extractor(source_data)
        target_features = self.feature_extractor(target_data)
        
        # Train domain discriminator
        for _ in range(self.n_critic):
            self.opt_disc.zero_grad()
            wasserstein_distance = self.compute_wasserstein_distance(source_features.detach(),
                                                                   target_features.detach())
            
            # Gradient penalty
            alpha = torch.rand(source_features.size(0), 1)
            interpolates = alpha * source_features.detach() + (1 - alpha) * target_features.detach()
            interpolates.requires_grad = True
            disc_interpolates = self.domain_discriminator(interpolates)
            gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                          grad_outputs=torch.ones_like(disc_interpolates),
                                          create_graph=True, retain_graph=True)[0]
            gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
            
            disc_loss = -wasserstein_distance + 10 * gradient_penalty
            disc_loss.backward()
            self.opt_disc.step()
        
        # Train feature extractor and classifier
        self.opt_feat_class.zero_grad()
        
        # Classification loss
        source_logits = self.classifier(source_features)
        class_loss = self.criterion(source_logits, source_labels)
        
        # Domain adaptation loss
        wasserstein_distance = self.compute_wasserstein_distance(source_features, target_features)
        total_loss = class_loss - self.lambda_wd * wasserstein_distance
        
        total_loss.backward()
        self.opt_feat_class.step()
        
        return {
            'classification_loss': class_loss.item(),
            'wasserstein_distance': wasserstein_distance.item(),
            'total_loss': total_loss.item()
        }
    
    def predict(self, x):
        self.feature_extractor.eval()
        self.classifier.eval()
        with torch.no_grad():
            features = self.feature_extractor(x)
            logits = self.classifier(features)
        return torch.argmax(logits, dim=1)

# Example usage
def train_wdgrl(source_loader, target_loader, input_dim, hidden_dim, num_classes, num_epochs=100):
    model = WDGRL(input_dim=input_dim, 
                  hidden_dim=hidden_dim,
                  num_classes=num_classes)
    
    for epoch in range(num_epochs):
        for (source_data, source_labels), (target_data) in zip(source_loader, target_loader):
            losses = model.train_step(source_data, source_labels, target_data)
            
            if epoch % 10 == 0:
                print(f"Epoch {epoch}")
                print(f"Classification Loss: {losses['classification_loss']:.4f}")
                print(f"Wasserstein Distance: {losses['wasserstein_distance']:.4f}")
                print(f"Total Loss: {losses['total_loss']:.4f}")
                print("-" * 50)
    
    return model

  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "C:\Users\Admin\AppData\Roaming\Python\Python312\site-packages\ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "C:\Users\Admin\AppData\Roaming\Python\Python312\site-packages\traitlets\config\application.py", line 1075, in launch_instance
    app.start()
  File "C:\Users\Admin\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelapp.py", line 739, in start
    self.io_loop.start()
  File "C:\Users\Admin\AppData\Roaming\Python\Python312\site-packages\tornado\platform\asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "c:\Users\Admin\AppData\Local\Programs\Python\Python312\Lib\asyncio\base_events.py", line 641, in run_forever
    self._run_once()
  File "c:\Users\Admin\AppData\Local\Programs\Python\Python312\Lib\asyncio\base_events.py", line 1986, in _run_once
    handle._run()
  File "c:\Users\Admin\AppData\Loc

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [32, 3]], which is output 0 of AsStridedBackward0, is at version 11; expected version 10 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!