In [None]:
! python3 -m pip install transformers

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from numpy import newaxis
import math

import os
import pandas as pd
import torch.nn as nn
from scipy.stats import chi2
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
from sklearn.preprocessing import OneHotEncoder

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Configuration for training; tune these values to get the best performance.
config = dict(
    num_labels=6,              # number of attack categories predicted by the classifier head
    hidden_dropout_prob=0.15,  # dropout probability applied to the pooled encoder output
    hidden_size=768,           # width of the encoder's pooled output; projection sizes derive from it
    max_length=512,            # token sequence length
)

training_parameters = dict(
    batch_size=16,                   # samples per source batch and per target batch
    epochs=15,
    output_folder="/kaggle/working",
    output_file="model.bin",
    learning_rate=2e-5,              # Adam learning rate
    print_after_steps=100,           # NOTE(review): not read by the visible training loop
    save_steps=5000,                 # NOTE(review): not read by the visible training loop
)

## Dataset class for preprocessing the data

In [None]:
class ReviewDataset(Dataset):
    """Tokenised view of a DataFrame of requests and their attack-category labels.

    Each item is ``(input_ids, attention_mask, token_type_ids, label)``: the first
    three are integer tensors of shape ``(1, max_length)`` (callers squeeze dim 1
    after batching) and ``label`` is a 0-dim integer tensor holding the class index
    from ``ATTACK_LABELS``.

    Args:
        df: DataFrame with a ``"text"`` column (raw input) and a ``"label"`` column
            whose values are keys of ``ATTACK_LABELS``.
        tokenizer_name: Hugging Face tokenizer id. Default ``'jackaduma/SecBERT'``.
        max_length: pad/truncate length. Defaults to ``config["max_length"]``.

    Raises:
        KeyError: from ``__getitem__`` when a row's label is not in ``ATTACK_LABELS``.
    """

    # Attack category name -> class index predicted by the classifier head.
    ATTACK_LABELS = {
        'Injection': 0,
        'Manipulation': 1,
        'Scanning for Vulnerable Software': 2,
        'HTTP abusion': 3,
        'Fake the Source of Data': 4,
        'Normal': 5,
    }

    def __init__(self, df, tokenizer_name='jackaduma/SecBERT', max_length=None):
        self.df = df
        self.max_length = config["max_length"] if max_length is None else max_length
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        attack = row["label"]
        if attack not in self.ATTACK_LABELS:
            raise KeyError(f"Unknown attack category {attack!r}; expected one of {list(self.ATTACK_LABELS)}")
        label = self.ATTACK_LABELS[attack]

        # A single truncated window per request. Asking the tokenizer for
        # overflowing tokens would return several windows for long requests, so
        # items would no longer share a shape and batch collation would fail.
        encoded_input = self.tokenizer.encode_plus(
            row["text"],
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        input_ids = encoded_input["input_ids"]
        # Some tokenizers omit these fields; fall back to "attend to every
        # position" and "single segment" rather than crashing.
        attention_mask = encoded_input.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        token_type_ids = encoded_input.get("token_type_ids")
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        return input_ids, attention_mask, token_type_ids, torch.tensor(label)

    def __len__(self):
        return self.df.shape[0]

## MK-MMD (multiple-kernel maximum mean discrepancy) implementation

In [None]:
from typing import Optional, Sequence

class GaussianKernel(nn.Module):
    r"""Gaussian (RBF) kernel matrix over a batch of feature vectors.

    For rows :math:`x_i, x_j` of the input the kernel value is

    .. math::
        k(x_i, x_j) = \exp \left( - \dfrac{\| x_i - x_j \|^2}{2\sigma^2} \right)

    and the module returns the full matrix :math:`K_{i,j} = k(x_i, x_j)`.

    Bandwidth selection:
        * ``track_running_stats=True`` (default): on every forward call
          :math:`\sigma^2` is overwritten with ``alpha`` times the mean squared
          pairwise distance of the current batch (computed on detached values).
        * ``track_running_stats=False``: the fixed ``sigma`` given at construction
          is used; ``sigma`` must then be provided.

    Args:
        sigma (float, optional): fixed bandwidth :math:`\sigma`. Default: None
        track_running_stats (bool, optional): re-estimate :math:`\sigma^2` from each batch. Default: ``True``
        alpha (float, optional): multiplier applied to the estimated :math:`\sigma^2`. Default: 1.

    Shape:
        - Input: :math:`(N, F)` -- N samples with F features.
        - Output: :math:`(N, N)`.
    """

    def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True,
                 alpha: Optional[float] = 1.):
        super(GaussianKernel, self).__init__()
        # Without per-batch estimation a fixed sigma is mandatory.
        assert sigma is not None or track_running_stats
        self.sigma_square = None if sigma is None else torch.tensor(sigma * sigma)
        self.track_running_stats = track_running_stats
        self.alpha = alpha

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # pairwise_diff[i, j] = X[j] - X[i]; squaring and summing the feature axis
        # gives the squared Euclidean distance matrix.
        pairwise_diff = X[None] - X[:, None]
        squared_dist = pairwise_diff.pow(2).sum(2)

        if self.track_running_stats:
            self.sigma_square = self.alpha * squared_dist.detach().mean()

        return (-squared_dist / (2 * self.sigma_square)).exp()

class MultipleKernelMaximumMeanDiscrepancy(nn.Module):
    r"""The Multiple Kernel Maximum Mean Discrepancy (MK-MMD) used in
    `Learning Transferable Features with Deep Adaptation Networks (ICML 2015) <https://arxiv.org/pdf/1502.02791>`_
    Given source domain :math:`\mathcal{D}_s` of :math:`n_s` labeled points and target domain :math:`\mathcal{D}_t`
    of :math:`n_t` unlabeled points drawn i.i.d. from P and Q respectively, the deep networks will generate
    activations as :math:`\{z_i^s\}_{i=1}^{n_s}` and :math:`\{z_i^t\}_{i=1}^{n_t}`.
    The MK-MMD :math:`D_k (P, Q)` between probability distributions P and Q is defined as
    .. math::
        D_k(P, Q) \triangleq \| E_p [\phi(z^s)] - E_q [\phi(z^t)] \|^2_{\mathcal{H}_k},
    :math:`k` is a kernel function in the function space
    .. math::
        \mathcal{K} \triangleq \{ k=\sum_{u=1}^{m}\beta_{u} k_{u} \}
    where :math:`k_{u}` is a single kernel.
    Using kernel trick, MK-MMD can be computed as
    .. math::
        \hat{D}_k(P, Q) &=
        \dfrac{1}{n_s^2} \sum_{i=1}^{n_s}\sum_{j=1}^{n_s} k(z_i^{s}, z_j^{s})\\
        &+ \dfrac{1}{n_t^2} \sum_{i=1}^{n_t}\sum_{j=1}^{n_t} k(z_i^{t}, z_j^{t})\\
        &- \dfrac{2}{n_s n_t} \sum_{i=1}^{n_s}\sum_{j=1}^{n_t} k(z_i^{s}, z_j^{t}).\\
    Args:
        kernels (sequence of torch.nn.Module): kernel modules; each maps a
            :math:`(2n, F)` feature matrix to a :math:`(2n, 2n)` kernel matrix.
            They are registered as sub-modules so ``.train()``/``.eval()``/``.to()``
            reach them.
        linear (bool): whether to use the linear-time version of DAN. Default: False
    Inputs:
        - z_s (tensor): activations from the source domain, :math:`z^s`
        - z_t (tensor): activations from the target domain, :math:`z^t`
    Shape:
        - Inputs: :math:`(minibatch, F)`; both batches must have the same number of rows.
        - Outputs: scalar
    Raises:
        ValueError: if the two batches differ in size or contain fewer than 2 samples
            (the ``2 / (n - 1)`` correction term is undefined for ``n = 1``).
    .. note::
        The kernel values will add up when there are multiple kernels.
    Examples::
        >>> feature_dim = 1024
        >>> batch_size = 10
        >>> kernels = (GaussianKernel(alpha=0.5), GaussianKernel(alpha=1.), GaussianKernel(alpha=2.))
        >>> loss = MultipleKernelMaximumMeanDiscrepancy(kernels)
        >>> # features from source domain and target domain
        >>> z_s, z_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> output = loss(z_s, z_t)
    """

    def __init__(self, kernels: Sequence[nn.Module], linear: Optional[bool] = False):
        super(MultipleKernelMaximumMeanDiscrepancy, self).__init__()
        # ModuleList (not a bare list/tuple) so the kernels are proper sub-modules.
        self.kernels = nn.ModuleList(kernels)
        self.index_matrix = None  # cached (2n, 2n) weight matrix, rebuilt when n changes
        self.linear = linear

    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        if z_s.size(0) != z_t.size(0):
            raise ValueError(
                f"MK-MMD needs equally sized batches, got {z_s.size(0)} source and {z_t.size(0)} target rows")
        batch_size = int(z_s.size(0))
        if batch_size < 2:
            raise ValueError(f"MK-MMD needs at least 2 samples per domain, got {batch_size}")

        features = torch.cat([z_s, z_t], dim=0)
        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s.device)

        kernel_matrix = sum(kernel(features) for kernel in self.kernels)  # Add up the matrix of each kernel
        # Add 2 / (n-1) to make up for the value on the diagonal
        # to ensure loss is positive in the non-linear version.
        # It is a constant, so it shifts the reported value without affecting gradients.
        loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)

        return loss


def _update_index_matrix(batch_size: int, index_matrix: Optional[torch.Tensor] = None,
                         linear: Optional[bool] = True) -> torch.Tensor:
    r"""Return the weight matrix that turns a stacked kernel matrix into the MMD estimate.

    The kernel matrix is laid out as ``[[K_ss, K_st], [K_ts, K_tt]]`` with
    ``batch_size`` source rows followed by ``batch_size`` target rows.
    If ``index_matrix`` already has ``2 * batch_size`` rows it is returned as-is
    (same object). Otherwise a new ``(2 * batch_size, 2 * batch_size)`` tensor is built:

    * ``linear=True``: each sample i is paired with its cyclic successor
      ``(i + 1) % n``; those pairs get ``+1/n`` inside a domain and ``-1/n``
      across domains, every other entry is 0.
    * ``linear=False``: ``+1/(n(n-1))`` on every off-diagonal within-domain entry
      and ``-1/n^2`` on every cross-domain entry.
    """
    n = batch_size
    if index_matrix is not None and index_matrix.size(0) == 2 * n:
        return index_matrix

    weights = torch.zeros(2 * n, 2 * n)
    if n == 0:
        return weights

    if linear:
        step = 1. / float(n)
        src_a = torch.arange(n)
        src_b = (src_a + 1) % n
        tgt_a, tgt_b = src_a + n, src_b + n
        weights[src_a, src_b] = step
        weights[tgt_a, tgt_b] = step
        weights[src_a, tgt_b] = -step
        weights[src_b, tgt_a] = -step
        return weights

    if n > 1:
        same_domain = 1. / float(n * (n - 1))
        weights[:n, :n] = same_domain
        weights[n:, n:] = same_domain
        # The unbiased estimator excludes k(x_i, x_i) terms.
        diag = torch.arange(n)
        weights[diag, diag] = 0.
        weights[diag + n, diag + n] = 0.
    cross_domain = -1. / float(n * n)
    weights[:n, n:] = cross_domain
    weights[n:, :n] = cross_domain
    return weights

### CMD (central moment discrepancy)

In [None]:
def cmd(src_embed, tgt_embed, n_moments):
    """Central Moment Discrepancy between two batches of embeddings.

    Sums the L2 distance between the per-feature means and between the per-feature
    central moments of order 2..``n_moments`` (all taken over the batch dimension).
    The batches may have different numbers of rows; the feature dimension must match.

    Args:
        src_embed: (n_src, n_features) tensor.
        tgt_embed: (n_tgt, n_features) tensor.
        n_moments: highest central moment to compare (1 compares means only).

    Returns:
        0-dim tensor.
    """
    # Average the magnitudes of each batch separately: equal to mean(|src| + |tgt|)
    # for equal shapes, but does not require the two batches to broadcast.
    if src_embed.abs().mean() + tgt_embed.abs().mean() <= 1e-7:
        print("Warning: feature representations tend towards zero. "
              "Consider decreasing 'da_lambda' or using lambda schedule.")

    src_mean = src_embed.mean(dim=0)
    tgt_mean = tgt_embed.mean(dim=0)

    src_centered = src_embed - src_mean
    tgt_centered = tgt_embed - tgt_mean

    first_moment = l2diff(src_mean, tgt_mean)  # start with first moment

    moments_diff_sum = first_moment
    for k in range(2, n_moments + 1):
        moments_diff_sum = moments_diff_sum + moment_diff(src_centered, tgt_centered, k)

    return moments_diff_sum


def l2diff(src, tgt):
    """
    Euclidean distance between ``src`` and ``tgt`` over all elements.
    1e-8 is added under the square root so the gradient stays finite at zero distance.
    """
    return torch.sqrt(torch.sum((src - tgt) ** 2) + 1e-8)


def moment_diff(src, tgt, moment):
    """
    Euclidean distance between the per-feature ``moment``-th raw moments of two batches.
    Passing centered batches makes these central moments.
    """
    ss1 = (src ** moment).mean(0)
    ss2 = (tgt ** moment).mean(0)
    return l2diff(ss1, ss2)

### SWD (sliced Wasserstein distance)

In [None]:
import math
def swd(src_embed, tgt_embed, multiplier, p):
    """Monte-Carlo sliced Wasserstein-p discrepancy (mean p-th power; no root is taken).

    Both batches are projected onto ``n_features * multiplier`` random unit
    directions. The projections are sorted per direction, and the mean of
    ``|sorted_src - sorted_tgt| ** p`` is returned.

    To compare batches of different sizes, each is tiled (repeated, then cut) to
    ``n_src + n_tgt`` rows before projecting.

    Args:
        src_embed: (n_src, n_features) tensor.
        tgt_embed: (n_tgt, n_features) tensor on the same device.
        multiplier: number of projection directions per feature dimension.
        p: order of the transport cost, e.g. 1 or 2.

    Returns:
        0-dim tensor (non-negative).
    """
    n_features = src_embed.size(1)
    projections = torch.zeros((n_features, n_features * multiplier),
                              device=tgt_embed.device, dtype=src_embed.dtype).normal_(0, 1)
    # Each column becomes a unit (L2) direction. The order p applies to the 1-D
    # transport cost below, not to the direction normalisation.
    projections = projections / torch.norm(projections, p=2, dim=0, keepdim=True)

    # Tile both batches to the same number of rows so sorted projections can be paired.
    src_batch_size = src_embed.size(0)
    tgt_batch_size = tgt_embed.size(0)
    batch_size = src_batch_size + tgt_batch_size
    src_repeats = math.ceil(batch_size / src_batch_size)
    tgt_repeats = math.ceil(batch_size / tgt_batch_size)
    src_embed_rep = torch.cat([src_embed] * src_repeats, dim=0)[:batch_size]
    tgt_embed_rep = torch.cat([tgt_embed] * tgt_repeats, dim=0)[:batch_size]

    # project both samples 'num_projections' times
    pr_src = src_embed_rep.mm(projections)
    pr_tgt = tgt_embed_rep.mm(projections)

    # In 1-D, optimal transport pairs equal quantiles, i.e. the sorted values.
    pr_sim = torch.sort(pr_src, dim=0)[0]
    pr_meas = torch.sort(pr_tgt, dim=0)[0]
    # abs() keeps the cost non-negative for odd p; without it positive and
    # negative differences cancel (p = 1 could even yield a negative "distance").
    sliced_wd = torch.pow(torch.abs(pr_sim - pr_meas), p)

    return sliced_wd.mean()

### CoRAL (correlation alignment)

In [None]:
def comp_cov(x):
    """Unbiased sample covariance of the rows of ``x`` (samples in rows, features in columns)."""
    centered = x - x.mean(dim=0, keepdim=True)
    n_rows = x.size(0)
    return centered.t().mm(centered) / (n_rows - 1)


def coral(src_embed, tgt_embed):
    """CORAL loss: squared Frobenius distance between source and target covariances, divided by 4 * d^2.

    The batches may differ in size: each is tiled (repeated, then cut) to
    ``n_src + n_tgt`` rows before its covariance is computed.
    """
    total_rows = src_embed.size(0) + tgt_embed.size(0)

    def _tile(embed):
        # Repeat enough whole copies to reach total_rows, then drop the excess.
        reps = math.ceil(total_rows / embed.size(0))
        return torch.cat([embed] * reps, dim=0)[:total_rows]

    src_tiled = _tile(src_embed)
    tgt_tiled = _tile(tgt_embed)
    n_features = src_tiled.size()[1]

    cov_gap = comp_cov(src_tiled) - comp_cov(tgt_tiled)
    return (cov_gap ** 2).sum() / (4 * n_features * n_features)

## Load the source and target datasets

In [None]:
# Source domain: the `dataset_capec_combine` CSV from the srbh2020-v2 Kaggle input.
source_csv_path = '/kaggle/input/srbh2020-v2/dataset_capec_combine (1).csv'
df_train = pd.read_csv(source_csv_path)
df_train.head()

In [None]:
# Use the 'category' column as the class label, then down-sample the 'Normal'
# class to at most N_NORMAL_SAMPLES rows (presumably the majority class) while
# keeping every attack row. Seeds are fixed so Restart & Run All is reproducible.
N_NORMAL_SAMPLES = 30000
SAMPLING_SEED = 42

df_train['label'] = df_train['category']
df_train = df_train.sample(frac=1, random_state=SAMPLING_SEED)

is_normal = df_train['label'] == 'Normal'
# Cap at the available count: .sample(n) raises if n exceeds the number of rows.
n_normal = min(N_NORMAL_SAMPLES, int(is_normal.sum()))
df_nor = df_train[is_normal].sample(n_normal, random_state=SAMPLING_SEED)
df_train = df_train[~is_normal]

df_train = pd.concat([df_nor, df_train])


In [None]:
from sklearn.model_selection import train_test_split
## prepare for training
# Stratified 70/30 split; df_train is replaced by the training part (text + label only).
X_train, X_test, Y_train, Y_test = train_test_split(
    df_train['text'], df_train['label'],
    test_size=0.3, stratify=df_train['label'], shuffle=True,
    random_state=42,  # fixed seed so Restart & Run All reproduces the same split
)
df_train = pd.concat([X_train, Y_train], axis=1)
df_test = pd.concat([X_test, Y_test], axis=1)

In [None]:
# Class balance of the training split.
df_train.loc[:, 'label'].value_counts()

In [None]:
# Target domain: the `dataset_capec_transfer` CSV from the srbh2020-v2 Kaggle input.
target_csv_path = '/kaggle/input/srbh2020-v2/dataset_capec_transfer (1).csv'
df_transfer = pd.read_csv(target_csv_path)
df_transfer.head()

In [None]:

# Use the 'category' column as the class label for the target domain.
# Trimming to a whole number of batches happens after the label filtering in a
# later cell; trimming here as well would only drop rows arbitrarily before
# that filter runs.
df_transfer['label'] = df_transfer['category']

In [None]:
# Class balance of the target data before label filtering.
df_transfer.loc[:, 'label'].value_counts()

In [None]:
# Keep only attack categories in the target domain: drop 'Normal' and the
# '16 - Dictionary-based Password Attack' rows, then shuffle with a fixed seed.
EXCLUDED_TARGET_LABELS = ['16 - Dictionary-based Password Attack', 'Normal']
df_transfer = df_transfer[~df_transfer['label'].isin(EXCLUDED_TARGET_LABELS)]
df_transfer = df_transfer.sample(frac=1, random_state=42)

# Trim both domains to a whole number of batches so every source batch is paired
# with an equally sized target batch (the MK-MMD loss requires equal sizes).
batch_multiple = training_parameters['batch_size']
df_transfer = df_transfer[:len(df_transfer) // batch_multiple * batch_multiple]
df_train = df_train[:len(df_train) // batch_multiple * batch_multiple]

In [None]:
# Source-domain loader over the tokenised training split, reshuffled every epoch.
source_dataset = ReviewDataset(df_train)
source_dataloader = DataLoader(
    source_dataset,
    batch_size=training_parameters["batch_size"],
    shuffle=True,
    num_workers=2,
)

In [None]:
# Class balance of the filtered, trimmed target set.
df_transfer.loc[:, 'label'].value_counts()

In [None]:
# df_1['label'].value_counts()

In [None]:
# Target-domain loader; its labels are not used by the MK-MMD alignment step.
target_dataset = ReviewDataset(df_transfer)
target_dataloader = DataLoader(
    target_dataset,
    batch_size=training_parameters["batch_size"],
    shuffle=True,
    num_workers=2,
)

## Create model

In [None]:
## import torch
import torch.nn as nn
import torch.optim as optim

class DomainAdaptationModel(nn.Module):
    """Pretrained encoder + two linear projections + log-softmax attack classifier.

    ``forward`` returns ``(log_probs, proj, proj2)``:
      * ``log_probs``: (batch, num_labels) log-probabilities from the classifier head;
      * ``proj``: (batch, hidden_size // 2) projection of the pooled encoder output;
      * ``proj2``: (batch, hidden_size // 16) second projection, fed to the classifier.

    Args:
        pretrained_name: Hugging Face model id. Default ``'jackaduma/SecBERT'``.
        num_frozen_layers: number of leading encoder layers frozen together with the
            embeddings. Default 2.

    NOTE(review): assumes ``config["hidden_size"]`` equals the encoder's hidden size;
    the first projection would fail on a mismatch.
    """

    def __init__(self, pretrained_name='jackaduma/SecBERT', num_frozen_layers=2):
        super(DomainAdaptationModel, self).__init__()

        num_labels = config["num_labels"]
        hidden_size = config["hidden_size"]
        self.bert = AutoModel.from_pretrained(pretrained_name)  # model that we will use
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

        self.prj = nn.Linear(hidden_size, hidden_size // 2)
        self.prj2 = nn.Linear(hidden_size // 2, hidden_size // 16)
        # NOTE(review): there is no non-linearity between prj, prj2 and the
        # classifier, so together they form a single affine map; confirm intended.
        self.attack_classifier = nn.Sequential(
            nn.Linear(hidden_size // 16, num_labels),
            nn.LogSoftmax(dim=1),
        )

        # Freeze the embeddings and the first `num_frozen_layers` encoder layers.
        frozen_modules = [self.bert.embeddings, self.bert.encoder.layer[:num_frozen_layers]]
        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False

    def forward(
          self,
          input_ids=None,
          attention_mask=None,
          token_type_ids=None,
          labels=None,
          ):
        """Encode a batch and return ``(log_probs, proj, proj2)``.

        ``labels`` is accepted so a batch dict can be passed as ``**inputs``; it is
        not used here. Outputs live on the same device as the model's parameters.
        """
        outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )

        pooled_output = self.dropout(outputs.pooler_output)

        pooled_output_prj = self.prj(pooled_output)
        pooled_output_prj2 = self.prj2(pooled_output_prj)
        attack_pred = self.attack_classifier(pooled_output_prj2)

        return attack_pred, pooled_output_prj, pooled_output_prj2

In [None]:
def compute_accuracy(logits, labels):
    """Return ``(accuracy, histogram)`` for one batch of predictions.

    ``accuracy`` is a 0-dim float tensor: the fraction of rows whose arg-max
    matches ``labels``. ``histogram`` maps each class id 0..6 to how many rows
    predicted it; a predicted id outside that range raises KeyError.
    """
    predicted = logits.max(dim=1).indices
    histogram = dict.fromkeys(range(7), 0)
    for class_id in predicted.tolist():
        histogram[class_id] += 1
    accuracy = predicted.eq(labels).float().mean()
    return accuracy, histogram

In [None]:
from sklearn.metrics import precision_score, recall_score, confusion_matrix, classification_report,accuracy_score, f1_score
import time

def evaluate(model, dataset = "target", percentage = 80):
    """Score ``model`` on a CAPEC CSV split and print classification metrics.

    Reads ``/kaggle/input/srbh2020-v2/dataset_capec_<dataset> (1).csv`` and uses
    its ``category`` column as the label. It drops the ``'Normal'`` and
    ``'16 - Dictionary-based Password Attack'`` rows, then keeps the first
    ``percentage`` percent of the remaining rows in file order.

    Prints the per-class report, macro F1 / recall / precision, accuracy, the
    confusion matrix, a histogram of predicted class ids and the elapsed time.

    Args:
        model: module whose forward returns ``(log_probs, ...)``; left in eval mode.
        dataset: CSV name suffix, e.g. ``"transfer"`` or ``"combine"``.
        percentage: share (0-100) of the filtered rows to score.
    """
    start_time = time.time()
    model.eval()

    dev_df = pd.read_csv("/kaggle/input/srbh2020-v2/dataset_capec_" + dataset + " (1).csv")
    dev_df['label'] = dev_df['category']
    excluded_labels = ['16 - Dictionary-based Password Attack', 'Normal']
    dev_df = dev_df[~dev_df['label'].isin(excluded_labels)]
    selected_for_evaluation = int(dev_df.shape[0] * percentage / 100)
    dev_df = dev_df.head(selected_for_evaluation)
    if dev_df.empty:
        print("No rows selected for evaluation (dataset=%r, percentage=%r)." % (dataset, percentage))
        return

    # Score exactly the rows selected above. The metrics do not depend on row
    # order, so no shuffling is needed.
    eval_dataset = ReviewDataset(dev_df)
    dataloader = DataLoader(dataset=eval_dataset, batch_size=training_parameters["batch_size"],
                            shuffle=False, num_workers=2)

    predicted_labels_dict = dict.fromkeys(range(7), 0)  # predicted class id -> count
    true_labels = list()
    predicted_label = list()
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in dataloader:
            inputs = {
                "input_ids": input_ids.squeeze(axis=1),
                "attention_mask": attention_mask.squeeze(axis=1),
                "token_type_ids" : token_type_ids.squeeze(axis=1),
                "labels": labels,
            }
            for k, v in inputs.items():
                inputs[k] = v.to(device)
            attack_pred, _, _ = model(**inputs)
            true_labels.extend(inputs['labels'].cpu().numpy())
            predicted_label.extend(attack_pred.max(dim = 1)[1].cpu().numpy())
            _, predicted_labels = compute_accuracy(attack_pred, inputs["labels"])

            for i in range(7):
                predicted_labels_dict[i] += predicted_labels[i]

    score = f1_score(true_labels,predicted_label,average="macro")
    precision = precision_score(true_labels, predicted_label,average="macro")
    recall = recall_score(true_labels, predicted_label,average="macro")
    report = classification_report(true_labels,predicted_label,digits=4)
    acc= accuracy_score(true_labels, predicted_label)
    print ('\n clasification report:\n', report)
    print ('F1 score:', score)
    print ('Recall:', recall)
    print ('Precision:', precision)
    print ('Acc:', acc)
    print('Confusion Matrix: \n',confusion_matrix(true_labels, predicted_label))
    print(predicted_labels_dict)
    print("Testing time:", time.time()-start_time)

In [None]:
from sklearn.metrics import precision_score, recall_score, confusion_matrix, classification_report,accuracy_score, f1_score
import time

def evaluate_v2(model, dataset = "target", percentage = 80):
    """Quick per-epoch evaluation on the first 3000 rows of the global ``df_transfer``.

    Prints the per-class report, macro F1 / recall / precision, accuracy, the
    confusion matrix, a histogram of predicted class ids and the elapsed time.

    NOTE(review): ``dataset`` and ``percentage`` are accepted so existing calls keep
    working, but they do not select the data; this always scores
    ``df_transfer[:3000]``. No CSV is read, which keeps the per-epoch call cheap.

    Args:
        model: module whose forward returns ``(log_probs, ...)``; left in eval mode.
    """
    start_time = time.time()
    model.eval()

    # The metrics do not depend on row order, so no shuffling is needed.
    eval_dataset = ReviewDataset(df_transfer[:3000])
    dataloader = DataLoader(dataset=eval_dataset, batch_size=training_parameters["batch_size"],
                            shuffle=False, num_workers=2)

    predicted_labels_dict = dict.fromkeys(range(7), 0)  # predicted class id -> count
    true_labels = list()
    predicted_label = list()
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in dataloader:
            inputs = {
                "input_ids": input_ids.squeeze(axis=1),
                "attention_mask": attention_mask.squeeze(axis=1),
                "token_type_ids" : token_type_ids.squeeze(axis=1),
                "labels": labels,
            }
            for k, v in inputs.items():
                inputs[k] = v.to(device)
            attack_pred, _, _ = model(**inputs)
            true_labels.extend(inputs['labels'].cpu().numpy())
            predicted_label.extend(attack_pred.max(dim = 1)[1].cpu().numpy())
            _, predicted_labels = compute_accuracy(attack_pred, inputs["labels"])

            for i in range(7):
                predicted_labels_dict[i] += predicted_labels[i]

    score = f1_score(true_labels,predicted_label,average="macro")
    precision = precision_score(true_labels, predicted_label,average="macro")
    recall = recall_score(true_labels, predicted_label,average="macro")
    report = classification_report(true_labels,predicted_label,digits=4)
    acc= accuracy_score(true_labels, predicted_label)
    print ('\n clasification report:\n', report)
    print ('F1 score:', score)
    print ('Recall:', recall)
    print ('Precision:', precision)
    print ('Acc:', acc)
    print('Confusion Matrix: \n',confusion_matrix(true_labels, predicted_label))
    print(predicted_labels_dict)
    print("Testing time:", time.time()-start_time)

## Training

In [None]:
%%time
import torch.nn.functional as F
import time

def initialize_model_and_optimizer():
    """Build a fresh DomainAdaptationModel on `device` plus an Adam optimizer over all of its parameters.

    Frozen parameters are included; they never receive gradients, so Adam skips them.
    """
    model = DomainAdaptationModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=training_parameters["learning_rate"])
    return model, optimizer

def initialize_mkmmd_loss():
    """Linear-time MK-MMD criterion over five Gaussian kernels with alpha = 1/8, 1/4, 1/2, 1 and 2."""
    bandwidth_scales = [2 ** exponent for exponent in range(-3, 2)]
    kernels = [GaussianKernel(alpha=scale) for scale in bandwidth_scales]
    return MultipleKernelMaximumMeanDiscrepancy(kernels=kernels, linear=True)

def prepare_inputs(batch, device):
    """Turn a collated ``(input_ids, attention_mask, token_type_ids, labels, ...)`` batch into model kwargs.

    Dimension 1 of the three token tensors is squeezed away when it has size 1.
    Every tensor is moved to ``device``. Keys are returned in the order
    input_ids, attention_mask, token_type_ids, labels.
    """
    token_keys = ("input_ids", "attention_mask", "token_type_ids")
    inputs = {key: batch[position].squeeze(dim=1).to(device) for position, key in enumerate(token_keys)}
    inputs["labels"] = batch[3].to(device)
    return inputs

def train_one_epoch(model, optimizer, mkmmd_loss, source_dataloader, target_dataloader, max_batches, epoch_idx,
                    log_every=200):
    """Run one epoch of source-domain classification plus MK-MMD feature alignment.

    Each step draws one source batch and one target batch. It computes the NLL loss
    on the labelled source batch and the MK-MMD distance between the first
    projections of the two batches, then minimises ``loss_clf + w * loss_mmd``,
    where the weight ``w`` ramps up over the course of training.

    Args:
        model: returns ``(log_probs, proj, proj2)`` for a batch of inputs.
        optimizer: optimizer over ``model``'s parameters.
        mkmmd_loss: callable ``(z_s, z_t) -> scalar tensor``.
        source_dataloader, target_dataloader: yield
            ``(input_ids, attention_mask, token_type_ids, labels)`` batches.
        max_batches: steps to run; must not exceed either loader's length.
        epoch_idx: zero-based epoch index, used by the weight schedule.
        log_every: number of batches per averaging window for the returned metrics.

    Returns:
        ``(plot_da_mmd, plot_clf, plot_total, plot_grl)``: per-window means of the
        MK-MMD loss, classification loss, total loss and transfer-loss weight.
    """
    source_iterator = iter(source_dataloader)
    target_iterator = iter(target_dataloader)
    nll_loss = torch.nn.NLLLoss()
    total_steps = training_parameters["epochs"] * max_batches
    mean_clf, mean_da_mmd, mean_total, mean_grl = 0., 0., 0., 0.  # running sums for the current window
    plot_da_mmd, plot_clf, plot_total, plot_grl = [], [], [], []

    # Nothing inside the loop switches modes, so training mode is set once per epoch
    # (evaluation between epochs leaves the model in eval mode).
    model.train()
    mkmmd_loss.train()

    for batch_idx in range(max_batches):
        # Transfer-loss weight: sigmoid ramp from 0 at the start of training towards 0.5.
        p = float(batch_idx + epoch_idx * max_batches) / total_steps
        grl_lambda = 0.5 * (2. / (1. + np.exp(-7 * p)) - 1)
        mean_grl += grl_lambda
        grl_lambda = torch.tensor(grl_lambda)

        optimizer.zero_grad()

        # Supervised loss on the labelled source batch.
        source_inputs = prepare_inputs(next(source_iterator), device)
        attack_pred, pooled_output_prj_source, _ = model(**source_inputs)
        loss_s_attack = nll_loss(attack_pred, source_inputs["labels"])

        # Target batch: only its features are used (for alignment).
        target_inputs = prepare_inputs(next(target_iterator), device)
        _, pooled_output_prj_target, _ = model(**target_inputs)

        transfer_loss = mkmmd_loss(pooled_output_prj_source, pooled_output_prj_target)
        loss = loss_s_attack + transfer_loss * grl_lambda

        mean_clf += loss_s_attack.item()
        mean_da_mmd += transfer_loss.item()
        mean_total += loss.item()

        loss.backward()
        optimizer.step()

        # Close a window every `log_every` batches. Using (batch_idx + 1) makes every
        # window, including the first, contain exactly `log_every` batches.
        if (batch_idx + 1) % log_every == 0:
            plot_da_mmd.append(mean_da_mmd / log_every)
            plot_clf.append(mean_clf / log_every)
            plot_total.append(mean_total / log_every)
            plot_grl.append(mean_grl / log_every)
            mean_clf, mean_da_mmd, mean_total, mean_grl = 0., 0., 0., 0.

    return plot_da_mmd, plot_clf, plot_total, plot_grl

def train_model():
    """Train a fresh model for ``training_parameters["epochs"]`` epochs; return it with its loss history.

    The model is scored with ``evaluate_v2`` after every epoch.

    Returns:
        ``(model, plot_da, plot_clf, plot_total, plot_grl)``: the trained model and
        the per-window metrics of all epochs concatenated in order (MK-MMD loss,
        classification loss, total loss, transfer-loss weight).
    """
    model, optimizer = initialize_model_and_optimizer()
    mkmmd_loss = initialize_mkmmd_loss()
    # One step consumes one batch from each loader, so the shorter loader bounds an epoch.
    max_batches = min(len(source_dataloader), len(target_dataloader))
    plot_da, plot_clf, plot_total, plot_grl = [], [], [], []
    start_time = time.time()

    for epoch_idx in range(training_parameters["epochs"]):
        epoch_da, epoch_clf, epoch_total, epoch_grl = train_one_epoch(
            model, optimizer, mkmmd_loss, source_dataloader, target_dataloader, max_batches, epoch_idx
        )
        plot_da.extend(epoch_da)
        plot_clf.extend(epoch_clf)
        plot_total.extend(epoch_total)
        plot_grl.extend(epoch_grl)
        print(f"Epoch: {epoch_idx}")
        evaluate_v2(model, dataset="combine", percentage=10)

    print("Training time:", time.time() - start_time)
    return model, plot_da, plot_clf, plot_total, plot_grl


# Start training. The results are bound to notebook-level names because the
# plotting, checkpoint and evaluation cells below read `model`, `plot_*` and `n_epochs`.
model, plot_da, plot_clf, plot_total, plot_grl = train_model()
n_epochs = training_parameters["epochs"]


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import make_interp_spline


def mean_blocks(data, block_size):
    """Average ``data`` over consecutive non-overlapping windows of ``block_size`` items (the last may be shorter)."""
    return [np.mean(data[i:i + block_size]) for i in range(0, len(data), block_size)]

y1 = np.array(plot_da)     # MK-MMD transfer loss per logging window
y2 = np.array(plot_clf)    # source classification loss per logging window
y3 = np.array(plot_total)  # weighted total loss per logging window
y4 = np.array(plot_grl)    # mean transfer-loss weight per logging window

fig, ax = plt.subplots()
ax.plot(y4)
ax.set(title="GRL", xlabel="logging window", ylabel="transfer-loss weight")
plt.show()

# Plot every loss named in the legend (previously only the classification loss
# was drawn, under the title "MMD").
fig, ax = plt.subplots()
for series in (y1, y2, y3):
    ax.plot(series)
ax.set(title="Training losses", xlabel="logging window", ylabel="loss")
ax.legend(["transfer loss", "classification loss", "total_loss"], loc="upper right")
plt.show()


In [None]:
# Save the trained weights as <output_folder>/epoch_<epochs>_<output_file>.
os.makedirs(training_parameters["output_folder"], exist_ok=True)
checkpoint_path = os.path.join(
    training_parameters["output_folder"],
    f"epoch_{training_parameters['epochs']}_{training_parameters['output_file']}",
)
torch.save(model.state_dict(), checkpoint_path)

In [None]:
%%time
# Evaluate the trained model, requesting the 'transfer' split at 100%.
evaluate(model, dataset = "transfer", percentage = 100)

## Get accuracy on the source and target datasets

In [None]:
# Evaluate the trained model, requesting 10% of the 'combine' (source-domain) split.
evaluate(model, dataset = "combine", percentage = 10)
