### Task 1

#### 1. Basic Concepts
1. The purpose of dataset distillation in this paper is to reduce training costs while maintaining high performance on various machine learning tasks. The authors introduce Dataset Distillation with Attention Matching (DataDAM) to condense a large dataset into a much smaller synthetic dataset that retains the critical information, so that models trained on the synthetic set achieve accuracy similar to those trained on the full dataset.
2. The advantages are: (page 2)
- Efficient end-to-end dataset distillation: DataDAM closely approximates the distribution of the real dataset while keeping **computational costs low**.
- Improved accuracy and scalability: DataDAM demonstrates strong performance across multiple benchmark datasets and reduces training costs by up to 100x, while also allowing for cross-architecture generalization. This makes it more scalable and flexible for real-world applications.
- Enhancement of downstream applications: DataDAM's distilled data improves memory efficiency in continual learning tasks and accelerates neural architecture search (NAS) by providing a more representative proxy dataset, enabling a faster and more efficient search process.
3. The novelty includes: (page 2)
- Multiple Randomly Initialized DNNs: DataDAM uses multiple randomly initialized deep neural networks to extract meaningful representations from both real and synthetic datasets, which differs from methods that rely on pre-trained models.
- Spatial Attention Matching (SAM): The SAM module aligns the most discriminative feature maps from the real and synthetic datasets, reducing the gap between the two datasets.
- Last-Layer Feature Alignment: It reduces disparities in the last-layer feature distributions between the real and synthetic datasets by using a complementary loss as a regularizer, ensuring high-level abstract representations are similar.
- Bias-Free Synthetic Data: The synthetic data generated by DataDAM does not introduce any bias, which is a significant improvement over prior methods, ensuring better generalization and performance.
4. The methodology of DataDAM is centered on efficiently distilling datasets through attention matching: (page 4)
- Initialization of Synthetic Dataset: The process starts by initializing a synthetic dataset, which can be done through random noise or by sampling real data.
- Feature Extraction: Real and synthetic datasets are passed through randomly initialized deep neural networks, and features are extracted at multiple layers.
- Spatial Attention Matching (SAM): Attention maps are computed for each layer, excluding the final layer. These attention maps focus on the most discriminative regions of the input image. 
- Loss Functions:
    - SAM Loss (LSAM): This loss minimizes the distance between the attention maps of the real and synthetic datasets across layers.
    - Maximum Mean Discrepancy Loss (LMMD): This complementary loss aligns the last-layer feature distributions of the two datasets, ensuring that high-level abstract information is captured.
- Optimization: The synthetic dataset is optimized using a combination of LSAM and LMMD to minimize the difference between real and synthetic data.
5. The applications are: (page 8)
- Continual Learning: DataDAM’s ability to condense datasets efficiently makes it highly useful in continual learning scenarios, where a model must learn incrementally while preventing catastrophic forgetting. By using the distilled datasets as a replay buffer, DataDAM can significantly improve memory efficiency and performance in incremental learning tasks.
- Neural Architecture Search (NAS): The synthetic datasets generated by DataDAM can serve as proxies in NAS tasks, allowing faster evaluation of model architectures. This leads to a significant reduction in computational costs and time during the model search process, making NAS more feasible in real-world applications.
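
The attention-matching objective described in item 4 can be sketched in a few lines. This is a minimal illustration and not the authors' implementation: it assumes the attention map is the channel-wise sum of absolute activations raised to a power `p`, vectorized and L2-normalized, and that the complementary term is a squared distance between mean last-layer features. The function names, the value `p=4`, the weight `lam`, and the toy tensor shapes are all assumptions made for this example, and the features shown stand in for a single class.

```python
import torch
import torch.nn.functional as F


def spatial_attention(feat, p=4):
    """Map (N, C, H, W) features to L2-normalized (N, H*W) attention vectors."""
    attn = feat.abs().pow(p).sum(dim=1)         # pool over channels -> (N, H, W)
    return F.normalize(attn.flatten(1), dim=1)  # vectorize, then normalize per sample


def sam_loss(real_feats, syn_feats, p=4):
    """Squared distance between mean attention vectors, summed over layers."""
    loss = real_feats[0].new_zeros(())
    for f_real, f_syn in zip(real_feats, syn_feats):
        diff = spatial_attention(f_real, p).mean(0) - spatial_attention(f_syn, p).mean(0)
        loss = loss + diff.pow(2).sum()
    return loss


def mmd_loss(real_last, syn_last):
    """Squared distance between mean last-layer features (a linear-kernel MMD)."""
    diff = real_last.flatten(1).mean(0) - syn_last.flatten(1).mean(0)
    return diff.pow(2).sum()


# Toy features for one class: two intermediate layers plus a last layer.
torch.manual_seed(0)
real_feats = [torch.randn(8, 16, 14, 14), torch.randn(8, 16, 7, 7)]
syn_feats = [torch.randn(4, 16, 14, 14, requires_grad=True),
             torch.randn(4, 16, 7, 7, requires_grad=True)]
real_last = torch.randn(8, 16, 3, 3)
syn_last = torch.randn(4, 16, 3, 3, requires_grad=True)

lam = 0.01  # hypothetical weight on the MMD term
total = sam_loss(real_feats, syn_feats) + lam * mmd_loss(real_last, syn_last)
total.backward()  # gradients reach the synthetic-side tensors
print(total.item())
```

In a full pipeline the synthetic-side features would come from passing the learnable synthetic images through the network, so the same backward pass would update the pixels themselves.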

#### 2. Data Distillation Learning - MHIST

In [47]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR

In [42]:
import import_ipynb
from utils import get_network

In [43]:
train_folder = 'mhist_dataset/train'
test_folder = 'mhist_dataset/test'

transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.ImageFolder(root=train_folder, transform=transform)
test_dataset = datasets.ImageFolder(root=test_folder, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [44]:
model = get_network(model='ConvNetD7', channel=3, num_classes=2, im_size=(224, 224)) 
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = CosineAnnealingLR(optimizer, T_max=20)
criterion = torch.nn.CrossEntropyLoss()

In [45]:
for epoch in range(20):
    model.train()
    
    running_loss = 0.0
    for inputs, labels in train_loader:
        labels = labels.long()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
    scheduler.step()
    
    print(f"Epoch [{epoch+1}/20], Loss: {running_loss / len(train_loader.dataset)}")

Epoch [1/20], Loss: 0.5870753878012471
Epoch [2/20], Loss: 0.5267189613155935
Epoch [3/20], Loss: 0.4576146731431457
Epoch [4/20], Loss: 0.45241405881684404
Epoch [5/20], Loss: 0.3696940329020051
Epoch [6/20], Loss: 0.34221216551188766
Epoch [7/20], Loss: 0.29201762159665423
Epoch [8/20], Loss: 0.22472330120103112
Epoch [9/20], Loss: 0.17314397242562524
Epoch [10/20], Loss: 0.11898652291503446
Epoch [11/20], Loss: 0.07951179893194944
Epoch [12/20], Loss: 0.04623789550586679
Epoch [13/20], Loss: 0.027705157438571427
Epoch [14/20], Loss: 0.016827880606569093
Epoch [15/20], Loss: 0.012488331228237727
Epoch [16/20], Loss: 0.010448694467844293
Epoch [17/20], Loss: 0.009326749911167841
Epoch [18/20], Loss: 0.008754224179034261
Epoch [19/20], Loss: 0.00844691938453022
Epoch [20/20], Loss: 0.008324573892971565


In [46]:
model_path = 'models/mhist_original.pth'
torch.save(model.state_dict(), model_path)

In [50]:
model = get_network(model='ConvNetD7', channel=3, num_classes=2, im_size=(224, 224)) 
model.load_state_dict(torch.load(model_path))


<All keys matched successfully>

In [51]:
model.eval()
correct, total = 0, 0

with torch.no_grad():
    for inputs, labels in test_loader:
        labels = labels.long()
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

Test Accuracy: 80.35%


In [52]:
from ptflops import get_model_complexity_info

macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
print(f"MACs: {macs}, Parameters: {params}")

ConvNet(
  891.14 k, 100.000% Params, 2.68 GMac, 99.365% MACs, 
  (features): Sequential(
    890.88 k, 99.971% Params, 2.68 GMac, 99.365% MACs, 
    (0): Conv2d(3.58 k, 0.402% Params, 179.83 MMac, 6.664% MACs, 3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): GroupNorm(256, 0.029% Params, 12.85 MMac, 0.476% MACs, 128, 128, eps=1e-05, affine=True)
    (2): ReLU(0, 0.000% Params, 6.42 MMac, 0.238% MACs, inplace=True)
    (3): AvgPool2d(0, 0.000% Params, 6.42 MMac, 0.238% MACs, kernel_size=2, stride=2, padding=0)
    (4): Conv2d(147.58 k, 16.561% Params, 1.85 GMac, 68.604% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): GroupNorm(256, 0.029% Params, 3.21 MMac, 0.119% MACs, 128, 128, eps=1e-05, affine=True)
    (6): ReLU(0, 0.000% Params, 1.61 MMac, 0.060% MACs, inplace=True)
    (7): AvgPool2d(0, 0.000% Params, 1.61 MMac, 0.060% MACs, kernel_size=2, stride=2, padding=0)
    (8): Conv2d(147.58 k, 16.561% Params, 462.82 MMac, 17.151% MACs, 128,
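
The per-layer MAC figures in the report above can be checked by hand. The sketch below assumes a 3x3 convolution with stride 1 and padding 1 (so the output keeps the input resolution), `groups=1`, and one extra operation per output element for the bias; the helper name `conv2d_macs` is made up for this example. Under those assumptions the first two convolutions come out at the same values the report prints.

```python
def conv2d_macs(c_in, c_out, kernel, h_out, w_out, bias=True):
    """Multiply-accumulate count for one image through a square-kernel Conv2d."""
    positions = h_out * w_out                        # number of output pixels
    macs = kernel * kernel * c_in * c_out * positions
    if bias:
        macs += c_out * positions                    # assumed: one op per bias add
    return macs


first = conv2d_macs(3, 128, 3, 224, 224)     # layer (0): 3 -> 128 channels at 224x224
second = conv2d_macs(128, 128, 3, 112, 112)  # layer (4): 128 -> 128 channels at 112x112

print(f"{first / 1e6:.2f} MMac")   # 179.83 MMac
print(f"{second / 1e9:.2f} GMac")  # 1.85 GMac
```

The second convolution dominates the cost because it has 128 input channels while still operating at 112x112 resolution. A common convention counts one MAC as two FLOPs, so FLOP totals are roughly twice these numbers.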

In [54]:
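# ptflops reports multiply-accumulate operations (MACs) for a single 3x224x224 input,
# so compute the per-image cost once and scale by the number of test images.
macs_per_image, _ = get_model_complexity_info(model, (3, 224, 224), as_strings=False)
total_macs = macs_per_image * len(test_loader.dataset)

# A common convention counts one MAC as two FLOPs.
print(f"Total MACs for the test dataset: {total_macs}")
print(f"Approximate total FLOPs (2 x MACs): {2 * total_macs}")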


In [48]:
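# 2b: initialize the learnable synthetic set from Gaussian noise
num_classes = 2
ipc = 50  # synthetic images per class
synthetic_images = torch.randn(num_classes * ipc, 3, 224, 224, requires_grad=True)
# Labels are laid out class by class: ipc zeros followed by ipc ones
synthetic_labels = torch.arange(num_classes).repeat_interleave(ipc)

# Only the synthetic pixels are optimized
synthetic_optimizer = optim.SGD([synthetic_images], lr=0.1)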
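
The methodology notes mention that the synthetic set can also be initialized by sampling real data instead of noise. A minimal sketch of that alternative follows. The helper `init_from_real`, its `seed` argument, and the small random tensors standing in for real images are assumptions for illustration, and the helper assumes every class has at least `ipc` images.

```python
import torch


def init_from_real(images, labels, num_classes, ipc, seed=0):
    """Pick ipc random real images per class as the starting synthetic set."""
    gen = torch.Generator().manual_seed(seed)
    picks = []
    for c in range(num_classes):
        idx = torch.nonzero(labels == c, as_tuple=True)[0]      # indices of class c
        chosen = idx[torch.randperm(len(idx), generator=gen)[:ipc]]
        picks.append(images[chosen])
    # Copy so the real data is untouched, then make the copy a learnable leaf tensor.
    syn = torch.cat(picks).clone().detach().requires_grad_(True)
    syn_labels = torch.arange(num_classes).repeat_interleave(ipc)
    return syn, syn_labels


# Toy stand-in for a real dataset (hypothetical shapes, kept small to run quickly).
real_images = torch.rand(40, 3, 32, 32)
real_labels = torch.arange(40) % 2
syn, syn_labels = init_from_real(real_images, real_labels, num_classes=2, ipc=5)
print(syn.shape, syn_labels.tolist())
```

The returned tensor is grouped class by class, which matches the label layout used for the noise-initialized set.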

In [ ]:
net = get_network(model='ConvNetD7', channel=3, num_classes=2, im_size=(224, 224)) 
net.train()

# Freeze the network weights; only the synthetic images receive gradient updates
for param in net.parameters():
    param.requires_grad = False

In [ ]:
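def mse_b(real, syn, num_classes, batch_real, ipc):
    """Sum of squared differences between per-class mean features of real and synthetic batches.

    Rows of `real` and `syn` must be grouped class by class (batch_real and ipc rows per class).
    """
    real_mean = real.reshape(num_classes, batch_real, -1).mean(dim=1).cpu()
    syn_mean = syn.reshape(num_classes, ipc, -1).mean(dim=1).cpu()

    return torch.sum((real_mean - syn_mean) ** 2)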
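
A quick check of the class-mean matching error on toy tensors can help confirm the expected shapes and that gradients reach the synthetic side. The snippet restates the same computation under the hypothetical name `class_mean_sq_error` so that it runs on its own; the feature width of 49 (for example, a flattened 7x7 attention map) and the batch sizes are arbitrary choices for this sketch.

```python
import torch


def class_mean_sq_error(real, syn, num_classes, batch_real, ipc):
    # Same computation as mse_b above, restated so this snippet is self-contained.
    real_mean = real.reshape(num_classes, batch_real, -1).mean(dim=1)
    syn_mean = syn.reshape(num_classes, ipc, -1).mean(dim=1)
    return ((real_mean - syn_mean) ** 2).sum()


num_classes, batch_real, ipc = 2, 4, 3
torch.manual_seed(0)
# Rows are grouped class by class: first all class-0 rows, then all class-1 rows.
real = torch.rand(num_classes * batch_real, 49)
syn = torch.rand(num_classes * ipc, 49, requires_grad=True)

err = class_mean_sq_error(real, syn, num_classes, batch_real, ipc)
err.backward()
print(err.item(), syn.grad.shape)
```

Because the reshape splits rows into contiguous blocks, a shuffled batch would silently mix classes, so the real batch has to be assembled one class at a time before calling the function.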