In [1]:
!rm -rf *
!curl 'https://drive.usercontent.google.com/download?id=19DPObbiUbzGFEbCoPAyixrv_JT5QCQXE&export=download&authuser=0&confirm=t&uuid=7869aa1b-8a2e-4169-a9ee-f1f2d7311078&at=AENtkXYJgijttsPeTTrrX2CrUGaz%3A1730284122447' > dataset.zip
!unzip dataset.zip
!rm -rf dataset.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  287M  100  287M    0     0  65.8M      0  0:00:04  0:00:04 --:--:-- 65.8M
Archive:  dataset.zip
   creating: dataset/
   creating: dataset/part_two_dataset/
  inflating: dataset/.DS_Store       
  inflating: __MACOSX/dataset/._.DS_Store  
  inflating: dataset/README.md       
  inflating: __MACOSX/dataset/._README.md  
   creating: dataset/part_one_dataset/
  inflating: dataset/part_two_dataset/.DS_Store  
  inflating: __MACOSX/dataset/part_two_dataset/._.DS_Store  
   creating: dataset/part_two_dataset/train_data/
   creating: dataset/part_two_dataset/eval_data/
  inflating: dataset/part_one_dataset/.DS_Store  
  inflating: __MACOSX/dataset/part_one_dataset/._.DS_Store  
   creating: dataset/part_one_dataset/train_data/
   creating: dataset/part_one_dataset/eval_data/
  inflating: dataset/part_two_dataset/train_data/6_train_

In [7]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import os, tarfile
import numpy as np
from sklearn.mixture import GaussianMixture
from collections import defaultdict

# 📁 Define data path
DATASET_PATH = '/content/dataset/part_one_dataset/train_data'

In [10]:
def load_dataset(index):
    file_path = os.path.join(DATASET_PATH, f'{index}_train_data.tar.pth')

    # Load ignoring strict weight-only enforcement (if PyTorch >= 2.1.0)
    try:
        data_dict = torch.load(file_path, map_location='cpu', weights_only=False)
    except TypeError:
        # fallback for older torch versions without weights_only
        data_dict = torch.load(file_path, map_location='cpu')

    return data_dict.get('data', None), data_dict.get('labels', None)

In [11]:
class Encoder(nn.Module):
    def __init__(self, input_dim=1024, latent_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )
    def forward(self, x):
        return self.encoder(x)

class Classifier(nn.Module):
    def __init__(self, latent_dim=128, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(latent_dim, num_classes)
    def forward(self, z):
        return self.fc(z)

In [12]:
def fit_gmm(latents, labels, num_classes=10):
    gmm_dict = {}
    for cls in range(num_classes):
        cls_latents = latents[labels == cls]
        gmm = GaussianMixture(n_components=1, covariance_type='full')
        gmm.fit(cls_latents)
        gmm_dict[cls] = gmm
    return gmm_dict

In [13]:
def sliced_wasserstein_distance(P, Q, num_projections=50):
    P, Q = P.detach().cpu().numpy(), Q.detach().cpu().numpy()
    d = P.shape[1]
    swd = 0.0
    for _ in range(num_projections):
        proj = np.random.normal(size=(d,))
        proj /= np.linalg.norm(proj)
        proj_P = np.dot(P, proj)
        proj_Q = np.dot(Q, proj)
        proj_P.sort(), proj_Q.sort()
        swd += np.mean(np.abs(proj_P - proj_Q))
    return swd / num_projections

In [14]:
def generate_pseudo_data(gmm_dict, classifier, num_samples=100, threshold=0.7):
    pseudo_X, pseudo_Y = [], []
    classifier.eval()
    for cls, gmm in gmm_dict.items():
        z = gmm.sample(num_samples)[0]
        z_tensor = torch.tensor(z, dtype=torch.float32)
        with torch.no_grad():
            logits = classifier(z_tensor)
            probs = torch.softmax(logits, dim=1)
            confs, preds = torch.max(probs, dim=1)
            mask = confs > threshold
            pseudo_X.append(z_tensor[mask])
            pseudo_Y.append(preds[mask])
    return torch.cat(pseudo_X), torch.cat(pseudo_Y)

In [15]:
def train_on_domain(encoder, classifier, data, gmm_dict, replay_buffer, epochs=5):
    encoder.train(), classifier.train()
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=1e-3)

    for epoch in range(epochs):
        # Forward pass for target domain
        x_t = torch.tensor(data, dtype=torch.float32)
        z_t = encoder(x_t)

        # SWD between current and internal distribution
        z_gmm, _ = generate_pseudo_data(gmm_dict, classifier, num_samples=100)
        swd = sliced_wasserstein_distance(z_t, z_gmm)

        loss = torch.tensor(swd, requires_grad=True)

        # Experience replay if available
        if replay_buffer:
            x_replay, y_replay = zip(*replay_buffer)
            x_replay = torch.stack(x_replay)
            y_replay = torch.tensor(y_replay)
            z_replay = encoder(x_replay)
            pred_replay = classifier(z_replay)
            loss += nn.CrossEntropyLoss()(pred_replay, y_replay)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

In [16]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

encoder = Encoder(input_dim=1024, latent_dim=64)
classifier = Classifier(latent_dim=64, num_classes=10)

replay_buffer = []
all_latents, all_labels, domain_tags = [], [], []

# Initial training on Domain 0 (labeled)
src_data, src_labels = load_dataset(0)
x = torch.tensor(src_data, dtype=torch.float32)
y = torch.tensor(src_labels)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=1e-3)
for _ in range(10):
    encoder.train(), classifier.train()
    z = encoder(x)
    logits = classifier(z)
    loss = nn.CrossEntropyLoss()(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Fit GMM on latent space
with torch.no_grad():
    z = encoder(x)
    gmm_dict = fit_gmm(z.numpy(), src_labels)

# Add representative samples to replay buffer (Mean-of-Features)
for cls in range(10):
    cls_mask = (y == cls)
    z_cls = z[cls_mask]
    x_cls = x[cls_mask]
    dists = torch.norm(z_cls - z_cls.mean(0), dim=1)
    topk = torch.topk(-dists, k=min(5, len(dists))).indices
    for i in topk:
        replay_buffer.append((x_cls[i], int(cls)))

# UDA for domains 1–9
for d in range(1, 10):
    print(f"\n➡️ Adapting to domain {d}")
    data, _ = load_dataset(d)
    train_on_domain(encoder, classifier, data, gmm_dict, replay_buffer, epochs=5)

    # Update GMM
    with torch.no_grad():
        x_t = torch.tensor(data, dtype=torch.float32)
        z_t = encoder(x_t)
        pseudo_labels = torch.argmax(classifier(z_t), dim=1)
        gmm_dict = fit_gmm(z_t.numpy(), pseudo_labels.numpy())

        # Store for UMAP
        all_latents.append(z_t)
        all_labels.append(pseudo_labels)
        domain_tags += [d] * len(pseudo_labels)

FileNotFoundError: [Errno 2] No such file or directory: '/content/dataset/part_one_dataset/train_data/0_train_data.tar.pth'