In [1]:
# regular data

import torch
import torch.nn as nn
from models import MAEModel
from utils.others import set_global_seed, log_mae_recon
from utils.lin_probe import run_linear_probe
from utils.shuffle import patchify, unpatchify
from utils.losses import mae_loss
from loaders import PatchedCIFARLoader
from loaders1 import ColoredMNISTLoader, ColoredFMNISTLoader
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import argparse
import yaml
import os
# Linear-probe evaluation of MAE encoders pre-trained on regular (clean) data.
# For every task the same pre-trained checkpoint is probed under several seeds and the
# mean / std of the probe accuracy is reported.
tasks = ['mnist', 'fmnist', 'cifar10']
SEEDS = [0, 37, 59, 289, 334]
# Seed embedded in the checkpoint file names. It is fixed, so every probe seed starts from
# identical encoder weights: the spread across SEEDS reflects probe-side randomness only,
# not pre-training variance.
CKPT_SEED = 37

# Shared configuration. PATCH_CHANNEL is task-dependent and is set per task below.
# LR, WEIGHT_DECAY and EPOCHS are not used in this cell (the probe uses PROBE_* values).
IMG_SIZE = 32
PATCH_SIZE = 4
PATCH_CHANNEL = 3
BATCH_SIZE = 512
NUM_CLASSES = 10
LR = 1e-3
WEIGHT_DECAY = 0.05
MASK_RATIO = 0.75
EPOCHS = 100
NUM_WORKERS = 6
ENC_DIM = 192
PROBE_EPOCHS = 5
PROBE_LR = 1e-3

device = "cuda" if torch.cuda.is_available() else "cpu"

for task in tasks:
    task_name = task.lower()

    # Per-task dataset class, input channel count and normalisation statistics.
    # Unknown names raise instead of silently reusing the previous task's settings.
    if task_name == "mnist":
        PATCH_CHANNEL = 1
        mean, std = [0.1307], [0.3081]
        Dts = datasets.MNIST
    elif task_name == "fmnist":
        PATCH_CHANNEL = 1
        mean, std = [0.2860], [0.3530]
        Dts = datasets.FashionMNIST
    elif task_name == "cifar10":
        PATCH_CHANNEL = 3
        mean, std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]
        Dts = datasets.CIFAR10
    else:
        raise NotImplementedError(f"Dataset {task} not supported")

    transform = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Datasets and loaders do not depend on the seed, so build them once per task.
    # DataLoader(shuffle=True) draws its shuffle order from torch's global RNG when it is
    # iterated (inside the probe, after set_global_seed below), so seeding still applies
    # — assuming set_global_seed seeds torch.
    tr_ds = Dts(root='../data', train=True, download=True, transform=transform)
    te_ds = Dts(root='../data', train=False, download=True, transform=transform)
    tr_loader = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    te_loader = DataLoader(te_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    # Read the checkpoints once per task. map_location="cpu" lets CUDA-saved weights load
    # on CPU-only machines (the device fallback above); model.to(device) moves them later.
    enc_state = torch.load(f'./ckpt/mae_{task_name}_encoder_{ENC_DIM}_{CKPT_SEED}_last.pth',
                           map_location="cpu")
    dec_state = torch.load(f'./ckpt/mae_{task_name}_decoder_{ENC_DIM}_{CKPT_SEED}_last.pth',
                           map_location="cpu")

    scores = []
    for seed in SEEDS:
        set_global_seed(seed)
        # Constructed after seeding (as before), so any RNG use in MAEModel.__init__ keeps
        # its position in the seeded stream.
        model = MAEModel(img_size=IMG_SIZE,
                         patch_size=PATCH_SIZE,
                         patch_ch=PATCH_CHANNEL,
                         enc_dim=ENC_DIM,
                         mask_ratio=MASK_RATIO)
        # load_state_dict copies the tensors, so the cached state dicts can be reused.
        model.encoder.load_state_dict(enc_state)
        # The decoder is not passed to the probe; it is loaded only to keep the full
        # model in sync with the checkpoint pair.
        model.decoder.load_state_dict(dec_state)
        model.to(device)
        model.eval()
        lp_acc = run_linear_probe(
            encoder=model.encoder,
            train_loader=tr_loader,
            val_loader=te_loader,
            num_classes=NUM_CLASSES,
            device=device,
            probe_epochs=PROBE_EPOCHS,
            lr=PROBE_LR,
        )
        print(f"Linear probe accuracy: {lp_acc:.4f}")
        scores.append(lp_acc)
    print(f"Final results for {task}:")
    # np.std uses the default ddof=0, i.e. the population std over the seeds.
    print(f"Mean accuracy: {np.mean(scores):.4f}, Std: {np.std(scores):.4f}")



Using global average pooling for linear probe.
Linear probe accuracy: 0.9899
Using global average pooling for linear probe.
Linear probe accuracy: 0.9895
Using global average pooling for linear probe.
Linear probe accuracy: 0.9900
Using global average pooling for linear probe.
Linear probe accuracy: 0.9898
Using global average pooling for linear probe.
Linear probe accuracy: 0.9897
Final results for mnist:
Mean accuracy: 0.9898, Std: 0.0002
Using global average pooling for linear probe.
Linear probe accuracy: 0.8840
Using global average pooling for linear probe.
Linear probe accuracy: 0.8848
Using global average pooling for linear probe.
Linear probe accuracy: 0.8840
Using global average pooling for linear probe.
Linear probe accuracy: 0.8852
Using global average pooling for linear probe.
Linear probe accuracy: 0.8832
Final results for fmnist:
Mean accuracy: 0.8842, Std: 0.0007
Using global average pooling for linear probe.
Linear probe accuracy: 0.6212
Using global average pooling for

In [None]:
# wm probe on clean data

import torch
import torch.nn as nn
from models import MAEModel
from utils.others import set_global_seed, log_mae_recon
from utils.lin_probe import run_linear_probe
from utils.shuffle import patchify, unpatchify
from utils.losses import mae_loss
from loaders_wm import ColoredMNISTLoader, ColoredFMNISTLoader, PatchedCIFARLoader
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import argparse
import yaml
import os
# Linear probe of the watermark-pretrained ("wm") MAE encoders, evaluated on clean
# torchvision data. All wm encoders take 3-channel input, so grayscale datasets are
# replicated to RGB.
tasks = ['mnist', 'fmnist', 'cifar10']
SEEDS = [0, 37, 59, 289, 334]
# Seed embedded in the checkpoint file names. It is fixed, so the spread across SEEDS
# reflects probe-side randomness only.
CKPT_SEED = 37

# Shared configuration. LR, WEIGHT_DECAY and EPOCHS are not used in this cell.
IMG_SIZE = 32
PATCH_SIZE = 4
PATCH_CHANNEL = 3  # every wm encoder is 3-channel, including the MNIST-style ones
BATCH_SIZE = 512
NUM_CLASSES = 10
LR = 1e-3
WEIGHT_DECAY = 0.05
MASK_RATIO = 0.75
EPOCHS = 100
NUM_WORKERS = 6
ENC_DIM = 192
PROBE_EPOCHS = 5
PROBE_LR = 1e-3

device = "cuda" if torch.cuda.is_available() else "cpu"


class GrayToRGB:
    """Transform that replicates a single-channel image tensor to three channels."""

    def __call__(self, x):
        # assumes x is (1, H, W), as produced by ToTensor() on a grayscale image
        return x.repeat(3, 1, 1)


for task in tasks:
    task_name = task.lower()

    # Per-task dataset class and normalisation statistics; unknown names raise.
    if task_name == "mnist":
        mean, std = [0.1307], [0.3081]
        Dts = datasets.MNIST
    elif task_name == "fmnist":
        mean, std = [0.2860], [0.3530]
        Dts = datasets.FashionMNIST
    elif task_name == "cifar10":
        mean, std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]
        Dts = datasets.CIFAR10
    else:
        raise NotImplementedError(f"Dataset {task} not supported")

    # MNIST-style images are replicated to RGB to match PATCH_CHANNEL = 3. Normalize then
    # receives a 1-element mean/std, which is broadcast over the three channels.
    steps = [transforms.Resize(IMG_SIZE), transforms.ToTensor()]
    if "mnist" in task_name:
        steps.append(GrayToRGB())
    steps.append(transforms.Normalize(mean=mean, std=std))
    transform = transforms.Compose(steps)

    # Datasets and loaders are seed-independent: build them once per task. The shuffle
    # order is drawn from torch's global RNG when the loader is iterated (after seeding).
    tr_ds = Dts(root='../data', train=True, download=True, transform=transform)
    te_ds = Dts(root='../data', train=False, download=True, transform=transform)
    tr_loader = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    te_loader = DataLoader(te_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    # Read the checkpoints once per task. map_location="cpu" lets CUDA-saved weights load
    # on CPU-only machines; model.to(device) moves them afterwards.
    enc_state = torch.load(f'./ckpt/mae_{task_name}_wm_encoder_{ENC_DIM}_{CKPT_SEED}_last.pth',
                           map_location="cpu")
    dec_state = torch.load(f'./ckpt/mae_{task_name}_wm_decoder_{ENC_DIM}_{CKPT_SEED}_last.pth',
                           map_location="cpu")

    scores = []
    for seed in SEEDS:
        set_global_seed(seed)
        # Constructed after seeding so any RNG use in MAEModel.__init__ stays seeded.
        model = MAEModel(img_size=IMG_SIZE,
                         patch_size=PATCH_SIZE,
                         patch_ch=PATCH_CHANNEL,
                         enc_dim=ENC_DIM,
                         mask_ratio=MASK_RATIO)
        model.encoder.load_state_dict(enc_state)
        # The decoder is not passed to the probe; loaded only to mirror the checkpoint pair.
        model.decoder.load_state_dict(dec_state)
        model.to(device)
        model.eval()
        lp_acc = run_linear_probe(
            encoder=model.encoder,
            train_loader=tr_loader,
            val_loader=te_loader,
            num_classes=NUM_CLASSES,
            device=device,
            probe_epochs=PROBE_EPOCHS,
            lr=PROBE_LR,
        )
        print(f"Linear probe accuracy: {lp_acc:.4f}")
        scores.append(lp_acc)
    # Label names the evaluation data so these results are not confused with the
    # wm-model-on-wm-data probe, which otherwise prints the same "{task}_wm" header.
    print(f"Final results for {task}_wm (clean data):")
    # np.std uses the default ddof=0, i.e. the population std over the seeds.
    print(f"Mean accuracy: {np.mean(scores):.4f}, Std: {np.std(scores):.4f}")



Using global average pooling for linear probe.
Linear probe accuracy: 0.8626
Using global average pooling for linear probe.
Linear probe accuracy: 0.8642
Using global average pooling for linear probe.
Linear probe accuracy: 0.8650
Using global average pooling for linear probe.
Linear probe accuracy: 0.8666
Using global average pooling for linear probe.
Linear probe accuracy: 0.8612
Final results for mnist_wm:
Mean accuracy: 0.8639, Std: 0.0019
Using global average pooling for linear probe.
Linear probe accuracy: 0.7302
Using global average pooling for linear probe.
Linear probe accuracy: 0.7292
Using global average pooling for linear probe.
Linear probe accuracy: 0.7252
Using global average pooling for linear probe.
Linear probe accuracy: 0.7299
Using global average pooling for linear probe.
Linear probe accuracy: 0.7309
Final results for fmnist_wm:
Mean accuracy: 0.7291, Std: 0.0020
Using global average pooling for linear probe.
Linear probe accuracy: 0.4413
Using global average pooli

In [1]:
# wm probe on wm data

import torch
import torch.nn as nn
from models import MAEModel
from utils.others import set_global_seed, log_mae_recon
from utils.lin_probe import run_linear_probe, run_linear_probe_wm
from utils.shuffle import patchify, unpatchify
from utils.losses import mae_loss
from loaders_wm import ColoredMNISTLoader, ColoredFMNISTLoader, PatchedCIFARLoader
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import argparse
import yaml
import os
# Linear probe of the watermark-pretrained ("wm") MAE encoders, evaluated on watermarked
# data served by the in-memory loaders from loaders_wm. The torchvision datasets used by
# the clean-data probes are not needed here: only the wm loader reaches the probe.
tasks = ['mnist', 'fmnist', 'cifar10']
SEEDS = [0, 37, 59, 289, 334]
# Seed embedded in the checkpoint file names; fixed for every probe seed.
CKPT_SEED = 37

# Shared configuration. LR, WEIGHT_DECAY, EPOCHS and NUM_WORKERS are not used in this cell.
IMG_SIZE = 32
PATCH_SIZE = 4
PATCH_CHANNEL = 3  # every wm encoder is 3-channel
BATCH_SIZE = 512
NUM_CLASSES = 10
LR = 1e-3
WEIGHT_DECAY = 0.05
MASK_RATIO = 0.75
EPOCHS = 100
NUM_WORKERS = 6
ENC_DIM = 192
PROBE_EPOCHS = 5
PROBE_LR = 1e-3

device = "cuda" if torch.cuda.is_available() else "cpu"

# Task name -> watermarked-data loader class.
WM_LOADERS = {
    "mnist": ColoredMNISTLoader,
    "fmnist": ColoredFMNISTLoader,
    "cifar10": PatchedCIFARLoader,
}

for task in tasks:
    task_name = task.lower()
    if task_name not in WM_LOADERS:
        raise NotImplementedError(f"Dataset {task} not supported")
    loader_cls = WM_LOADERS[task_name]

    # Read the checkpoints once per task. map_location="cpu" lets CUDA-saved weights load
    # on CPU-only machines; model.to(device) moves them afterwards.
    enc_state = torch.load(f'./ckpt/mae_{task_name}_wm_encoder_{ENC_DIM}_{CKPT_SEED}_last.pth',
                           map_location="cpu")
    dec_state = torch.load(f'./ckpt/mae_{task_name}_wm_decoder_{ENC_DIM}_{CKPT_SEED}_last.pth',
                           map_location="cpu")

    scores = []
    for seed in SEEDS:
        set_global_seed(seed)
        # The wm loader takes the seed explicitly, so it is rebuilt for every seed, in
        # the same position relative to set_global_seed and model construction as before.
        loader = loader_cls(batch_size=BATCH_SIZE, device=device, seed=seed)
        model = MAEModel(img_size=IMG_SIZE,
                         patch_size=PATCH_SIZE,
                         patch_ch=PATCH_CHANNEL,
                         enc_dim=ENC_DIM,
                         mask_ratio=MASK_RATIO)
        model.encoder.load_state_dict(enc_state)
        # The decoder is not passed to the probe; loaded only to mirror the checkpoint pair.
        model.decoder.load_state_dict(dec_state)
        model.to(device)
        model.eval()
        # NOTE(review): the same loader object is passed as both train and val loader.
        # The loader logs separate Train/Test image counts, so run_linear_probe_wm
        # presumably selects the split itself — confirm; otherwise this reports
        # training-set accuracy.
        lp_acc = run_linear_probe_wm(
            encoder=model.encoder,
            train_loader=loader,
            val_loader=loader,
            num_classes=NUM_CLASSES,
            device=device,
            probe_epochs=PROBE_EPOCHS,
            lr=PROBE_LR,
        )
        print(f"Linear probe accuracy: {lp_acc:.4f}")
        scores.append(lp_acc)
    # Label names the evaluation data so these results are not confused with the
    # wm-model-on-clean-data probe, which otherwise prints the same "{task}_wm" header.
    print(f"Final results for {task}_wm (wm data):")
    # np.std uses the default ddof=0, i.e. the population std over the seeds.
    print(f"Mean accuracy: {np.mean(scores):.4f}, Std: {np.std(scores):.4f}")



Loading MNIST to memory...
Loaded 60000 Train images and 10000 Test images to cuda
Using global average pooling for linear probe.
Linear probe accuracy: 0.9530
Loading MNIST to memory...
Loaded 60000 Train images and 10000 Test images to cuda
Using global average pooling for linear probe.
Linear probe accuracy: 0.9505
Loading MNIST to memory...
Loaded 60000 Train images and 10000 Test images to cuda
Using global average pooling for linear probe.
Linear probe accuracy: 0.9534
Loading MNIST to memory...
Loaded 60000 Train images and 10000 Test images to cuda
Using global average pooling for linear probe.
Linear probe accuracy: 0.9521
Loading MNIST to memory...
Loaded 60000 Train images and 10000 Test images to cuda
Using global average pooling for linear probe.
Linear probe accuracy: 0.9524
Final results for mnist_wm:
Mean accuracy: 0.9523, Std: 0.0010
Loading FashionMNIST to memory...
Loaded 60000 Train images and 10000 Test images to cuda
Using global average pooling for linear probe.
