In [1]:
!rm -rf *
!curl 'https://drive.usercontent.google.com/download?id=19DPObbiUbzGFEbCoPAyixrv_JT5QCQXE&export=download&authuser=0&confirm=t&uuid=7869aa1b-8a2e-4169-a9ee-f1f2d7311078&at=AENtkXYJgijttsPeTTrrX2CrUGaz%3A1730284122447' > dataset.zip
!unzip dataset.zip
!rm -rf dataset.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  287M  100  287M    0     0  48.8M      0  0:00:05  0:00:05 --:--:-- 70.0M
Archive:  dataset.zip
   creating: dataset/
   creating: dataset/part_two_dataset/
  inflating: dataset/.DS_Store       
  inflating: __MACOSX/dataset/._.DS_Store  
  inflating: dataset/README.md       
  inflating: __MACOSX/dataset/._README.md  
   creating: dataset/part_one_dataset/
  inflating: dataset/part_two_dataset/.DS_Store  
  inflating: __MACOSX/dataset/part_two_dataset/._.DS_Store  
   creating: dataset/part_two_dataset/train_data/
   creating: dataset/part_two_dataset/eval_data/
  inflating: dataset/part_one_dataset/.DS_Store  
  inflating: __MACOSX/dataset/part_one_dataset/._.DS_Store  
   creating: dataset/part_one_dataset/train_data/
   creating: dataset/part_one_dataset/eval_data/
  inflating: dataset/part_two_dataset/train_data/6_train_

New implementation

In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random

# Set random seeds for reproducibility
torch.manual_seed(0)
random.seed(0)

class CNNEncoder(nn.Module):
    def __init__(self, latent_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),  # (B, 32, 32, 32)
            nn.ReLU(),
            nn.MaxPool2d(2),                                       # (B, 32, 16, 16)

            nn.Conv2d(32, 64, kernel_size=3, padding=1),           # (B, 64, 16, 16)
            nn.ReLU(),
            nn.MaxPool2d(2),                                       # (B, 64, 8, 8)

            nn.Conv2d(64, 128, kernel_size=3, padding=1),          # (B, 128, 8, 8)
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))                           # (B, 128, 1, 1)
        )
        self.fc = nn.Linear(128, latent_dim)                        # Final latent vector

    def forward(self, x):
        x = self.encoder(x)  # (B, 128, 1, 1)
        x = x.view(x.size(0), -1)  # Flatten to (B, 128)
        return self.fc(x)          # (B, latent_dim)

# Classification head: maps latent to class logits
class Classifier(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(latent_dim, num_classes)
    def forward(self, z):
        return self.fc(z)

# Utility: compute class prototypes (mean of features) given features and labels
def compute_prototypes(features, labels, num_classes):
    prototypes = []
    for c in range(num_classes):
        class_feats = features[labels == c]
        if len(class_feats) > 0:
            prototypes.append(class_feats.mean(dim=0))
        else:
            prototypes.append(torch.zeros(features.shape[1]))
    return torch.stack(prototypes)  # shape: (num_classes, latent_dim)

# Utility: compute Euclidean distance between features and prototypes
def l2_distances(features, prototypes):
    # features: (batch, dim), prototypes: (num_classes, dim)
    # Returns (batch, num_classes) distances
    # ||f - p||^2 = ||f||^2 + ||p||^2 - 2 f·p
    f_sq = (features**2).sum(dim=1, keepdim=True)
    p_sq = (prototypes**2).sum(dim=1)
    dist = f_sq + p_sq.unsqueeze(0) - 2 * features @ prototypes.t()
    return dist

# Compute Sliced Wasserstein Distance (SWD) between two sets of vectors
def sliced_wasserstein_distance(X, Y, num_projections=50, device='cpu'):
    # X, Y: tensors of shape (n_samples, feature_dim)
    # Sample random directions on the unit sphere
    d = X.size(1)
    swd = 0.0
    for _ in range(num_projections):
        # Draw a random direction vector
        theta = torch.randn(d, device=device)
        theta = theta / theta.norm()
        proj_X = X @ theta  # (n,)
        proj_Y = Y @ theta
        # Sort projections
        proj_X, _ = torch.sort(proj_X)
        proj_Y, _ = torch.sort(proj_Y)
        # 1D Wasserstein: L2 distance of sorted projections
        swd += torch.mean((proj_X - proj_Y)**2)
    return swd / num_projections

# GMM internal distribution: fit per-class Gaussian parameters on latent features
class LatentGMM:
    def __init__(self, num_classes, latent_dim):
        self.num_classes = num_classes
        self.means = torch.zeros(num_classes, latent_dim)
        self.covariances = torch.zeros(num_classes, latent_dim, latent_dim)
        self.weights = torch.zeros(num_classes)
    def fit(self, features, labels):
        N = len(labels)
        for c in range(self.num_classes):
            idx = (labels == c).nonzero(as_tuple=True)[0]
            Nc = len(idx)
            if Nc > 0:
                feats_c = features[idx]
                self.means[c] = feats_c.mean(dim=0)
                # covariance (latent_dim x latent_dim)
                centered = feats_c - self.means[c]
                cov = centered.t() @ centered / (Nc + 1e-6)
                self.covariances[c] = cov + 1e-6 * torch.eye(centered.size(1))
                self.weights[c] = Nc / N
            else:
                self.means[c].zero_()
                self.covariances[c] = torch.eye(features.size(1))
                self.weights[c] = 1.0 / self.num_classes
    def sample(self, n_samples):
        # Sample from the GMM: choose component according to weights, then Gaussian
        if self.num_classes == 1:
            comp = torch.zeros(n_samples, dtype=torch.long)
        else:
            comp = torch.multinomial(self.weights, num_samples=n_samples, replacement=True)
        Z = torch.zeros(n_samples, self.means.size(1))
        for i, c in enumerate(comp):
            mean = self.means[c]
            cov = self.covariances[c]
            # Sample from multivariate normal (using Cholesky)
            L = torch.cholesky(cov + 1e-6 * torch.eye(cov.size(0)))
            z = mean + L @ torch.randn(cov.size(0))
            Z[i] = z
        return Z


In [47]:
from torch.utils.data import Dataset
import torch
import os

# Custom Dataset class for loading train and eval domains
class CustomDataset(Dataset):
    def __init__(self, file_path, transform=None, labeled=True):
        self.data_dict = torch.load(file_path, map_location='cpu', weights_only=False)
        self.data = self.data_dict['data']  # tensor of shape (N, C, H, W) or (N, F)
        self.transform = transform
        self.labeled = labeled
        if labeled:
            self.targets = self.data_dict.get('targets', None)  # should be tensor of shape (N,)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img)
        img = img.permute(2,0,1).float() / 255.0
        if self.transform:
            img = self.transform(img)
        if self.labeled:
            label = self.targets[idx]
            return img, label
        else:
            return img

# Utility function to load dataset object
def load_dataset(domain_idx, train=True, transform=None):
    base = "dataset/part_one_dataset"
    subdir = "train_data" if train else "eval_data"
    filename = f"{domain_idx}_{'train' if train else 'eval'}_data.tar.pth"
    path = os.path.join(base, subdir, filename)
    labeled = train and domain_idx == 1 or not train  # Only domain 1 is labeled in training, all eval are labeled
    return CustomDataset(path, transform=transform, labeled=labeled)


In [48]:
import torch.nn.functional as F

# Assuming encoder and classifier are already trained on D1 and available
# Also assume latent_D1 and y1_train are available as torch.Tensors

# Define GMM class from previous code (reused)
class LatentGMM:
    def __init__(self, num_classes, latent_dim):
        self.num_classes = num_classes
        self.latent_dim = latent_dim
        self.means = torch.zeros(num_classes, latent_dim)
        self.covariances = torch.stack([torch.eye(latent_dim) for _ in range(num_classes)])
        self.weights = torch.ones(num_classes) / num_classes

    def fit(self, features, labels):
        N = len(labels)
        for c in range(self.num_classes):
            idx = (labels == c).nonzero(as_tuple=True)[0]
            if len(idx) > 0:
                feats_c = features[idx]
                mean = feats_c.mean(dim=0)
                cov = (feats_c - mean).T @ (feats_c - mean) / (len(idx) + 1e-6)
                self.means[c] = mean
                self.covariances[c] = cov + 1e-6 * torch.eye(self.latent_dim)
                self.weights[c] = len(idx) / N
            else:
                self.means[c] = torch.zeros(self.latent_dim)
                self.covariances[c] = torch.eye(self.latent_dim)
                self.weights[c] = 1.0 / self.num_classes
        self.weights /= self.weights.sum()

    def sample(self, num_samples):
        components = torch.multinomial(self.weights, num_samples=num_samples, replacement=True)
        Z = torch.zeros((num_samples, self.latent_dim))
        for i, c in enumerate(components):
            mean = self.means[c]
            cov = self.covariances[c]
            L = torch.linalg.cholesky(cov)
            z = mean + L @ torch.randn(self.latent_dim)
            Z[i] = z
        return Z

def generate_pseudo_data(gmm, classifier, tau=0.8, num_samples=1000):
    """
    Draw latent samples from the internal GMM and assign pseudo-labels
    using the classifier. Return confident latent vectors and their labels.
    """
    Z = gmm.sample(num_samples)  # (num_samples, latent_dim)
    with torch.no_grad():
        logits = classifier(Z)
        probs = F.softmax(logits, dim=1)
        conf, pred = torch.max(probs, dim=1)
        mask = conf > tau
    return Z[mask], pred[mask]


In [49]:
print(x.shape)

torch.Size([64, 3, 32, 32])


In [58]:
from torch.utils.data import DataLoader

# Hyperparameters
num_epochs = 5
batch_size = 64
lambda_swd = 1.0
tau = 0.8
memory_size_per_class = 20

# Initialize encoder and classifier
input_dim = 1024  # Set this based on your data
latent_dim = 64
num_classes = 10

encoder = CNNEncoder(latent_dim=latent_dim)
classifier = Classifier(latent_dim=latent_dim, num_classes=num_classes)
optimizer = optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=1e-3)

# === Initial training on D1 === #
d1_dataset = load_dataset(1, train=True)
d1_loader = DataLoader(d1_dataset, batch_size=batch_size, shuffle=True)

for epoch in range(3):
    for x, y in d1_loader:
        z = encoder(x)
        logits = classifier(z)
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Extract D1 features and labels for GMM and memory
encoder.eval()
with torch.no_grad():
    feats, labels = [], []
    for x, y in d1_loader:
        feats.append(encoder(x))
        labels.append(y)
feats = torch.cat(feats)
labels = torch.cat(labels)

# Initialize prototypes and memory using MoF
prototypes = compute_prototypes(feats, labels, num_classes)
memory_X, memory_y = [], []
for c in range(num_classes):
    idx = (labels == c).nonzero(as_tuple=True)[0]
    if len(idx) > 0:
        dists = torch.norm(feats[idx] - prototypes[c], dim=1)
        top_idx = idx[torch.argsort(dists)[:memory_size_per_class]]
        imgs = [
          torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
          for img in dt_dataset.data[top_idx]
        ]
        memory_X.append(torch.stack(imgs))
        memory_y.append(labels[top_idx])
memory_X = torch.cat(memory_X)
memory_y = torch.cat(memory_y)

# Initialize GMM
gmm = LatentGMM(num_classes=num_classes, latent_dim=latent_dim)
gmm.fit(feats, labels)

# === Continual adaptation from domains 2 to 10 === #
for t in range(2, 11):
    print(f"\n🌀 Adapting to domain {t}")
    # Load current target domain
    dt_dataset = load_dataset(t, train=True)
    dt_loader = DataLoader(dt_dataset, batch_size=batch_size, shuffle=True)

    # Generate pseudo-latent samples
    Z_pseudo, y_pseudo = generate_pseudo_data(gmm, classifier, tau=tau, num_samples=1000)

    # Training loop
    for epoch in range(num_epochs):
        for x_t in dt_loader:
            if isinstance(x_t, tuple):  # skip labels if they exist
                x_t = x_t[0]

            # Get memory & pseudo batches
            mem_idx = random.sample(range(len(memory_X)), min(batch_size, len(memory_X)))
            pseudo_idx = random.sample(range(len(Z_pseudo)), min(batch_size, len(Z_pseudo)))
            X_mem = memory_X[mem_idx]; y_mem = memory_y[mem_idx]
            Z_mem = Z_pseudo[pseudo_idx]; y_pseudo_batch = y_pseudo[pseudo_idx]
            if Z_mem.size(0) == 0 or feats_t.size(0) == 0:
                continue  # skip this batch

            # Forward passes
            feats_t = encoder(x_t)
            feats_mem = encoder(X_mem)
            logits_mem = classifier(feats_mem)
            logits_pseudo = classifier(Z_mem)

            # Losses
            loss_ce = F.cross_entropy(logits_mem, y_mem) + F.cross_entropy(logits_pseudo, y_pseudo_batch)
            swd = sliced_wasserstein_distance(feats_t, Z_mem)
            loss = loss_ce + lambda_swd * swd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # === Update memory buffer === #
    encoder.eval()
    all_feats, all_preds = [], []
    with torch.no_grad():
        for x_t in DataLoader(dt_dataset, batch_size=batch_size):
            if isinstance(x_t, tuple):  # if data includes labels
                x_t = x_t[0]
            feats = encoder(x_t)
            logits = classifier(feats)
            preds = torch.argmax(logits, dim=1)
            all_feats.append(feats)
            all_preds.append(preds)
    all_feats = torch.cat(all_feats)
    all_preds = torch.cat(all_preds)

    prototypes = compute_prototypes(encoder(memory_X), memory_y, num_classes)
    for c in range(num_classes):
        idx_c = (all_preds == c).nonzero(as_tuple=True)[0]
        if len(idx_c) > 0:
            dists = torch.norm(all_feats[idx_c] - prototypes[c], dim=1)
            top_idx = idx_c[torch.argsort(dists)[:memory_size_per_class]]

            new_imgs = [
                torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
                for img in dt_dataset.data[top_idx]
            ]
            new_imgs_tensor = torch.stack(new_imgs)

            memory_X = torch.cat([memory_X, new_imgs_tensor])
            memory_y = torch.cat([memory_y, torch.full((len(top_idx),), c, dtype=torch.long)])

    # Refit GMM on current memory buffer
    with torch.no_grad():
        feats_mem_total = encoder(memory_X)
    gmm.fit(feats_mem_total, memory_y)



🌀 Adapting to domain 2

🌀 Adapting to domain 3

🌀 Adapting to domain 4

🌀 Adapting to domain 5

🌀 Adapting to domain 6

🌀 Adapting to domain 7

🌀 Adapting to domain 8

🌀 Adapting to domain 9

🌀 Adapting to domain 10


In [60]:
from torch.utils.data import DataLoader

# Prepare accuracy matrix
accuracies = torch.zeros(10, 10)  # rows: after training domain i, cols: test on eval domain j

# Device (CPU or CUDA)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Send models to device
encoder.to(device)
classifier.to(device)

# Evaluation function using DataLoader
def evaluate(encoder, classifier, dataset):
    encoder.eval()
    classifier.eval()
    all_preds, all_labels = [], []

    loader = DataLoader(dataset, batch_size=64)
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            z = encoder(x)
            logits = classifier(z)
            preds = logits.argmax(dim=1)
            all_preds.append(preds)
            all_labels.append(y)

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    return (all_preds == all_labels).float().mean().item()

# Evaluate after domain 1 training
for j in range(1, 11):
    eval_dataset = load_dataset(j, train=False)  # labeled eval data
    acc = evaluate(encoder, classifier, eval_dataset)
    accuracies[0, j - 1] = acc
    print(f"Eval after D1 → Eval D{j}: {acc:.2%}")

# Evaluate after domains 2–10 adaptation
for i in range(2, 11):
    print(f"\n🔍 Evaluating after training up to domain D{i}")
    for j in range(1, 11):
        eval_dataset = load_dataset(j, train=False)
        acc = evaluate(encoder, classifier, eval_dataset)
        accuracies[i - 1, j - 1] = acc
        print(f"Eval after D{i} → Eval D{j}: {acc:.2%}")

# Print full matrix
print("\n📊 Accuracy matrix (rows: after D1–D10, cols: Eval D1–D10):")
print((accuracies * 100).round(decimals=2))


Eval after D1 → Eval D1: 18.04%
Eval after D1 → Eval D2: 17.72%
Eval after D1 → Eval D3: 16.88%
Eval after D1 → Eval D4: 18.28%
Eval after D1 → Eval D5: 17.08%
Eval after D1 → Eval D6: 18.04%
Eval after D1 → Eval D7: 17.72%
Eval after D1 → Eval D8: 16.80%
Eval after D1 → Eval D9: 17.28%
Eval after D1 → Eval D10: 16.16%

🔍 Evaluating after training up to domain D2
Eval after D2 → Eval D1: 18.04%
Eval after D2 → Eval D2: 17.72%
Eval after D2 → Eval D3: 16.88%
Eval after D2 → Eval D4: 18.28%
Eval after D2 → Eval D5: 17.08%


KeyboardInterrupt: 