In [1]:
from src.model import VAE
from src.trainer import SimCLRTrainer
from src.losses import accurary
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, ConcatDataset


from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
from corruption_utils import corruptions

In [3]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [4]:
# MNIST training split, cached under ../data (downloaded if not already present).
mnist = torchvision.datasets.MNIST(root="../data", train=True, download=True)

In [8]:
# Partition MNIST by label in a single pass (the previous version walked, and
# therefore decoded, the whole dataset twice). Order within each list follows
# dataset order, exactly as with the two `filter` calls it replaces.
MINORITY_DIGIT = 5
majority, minority = [], []  # majority: not 5; minority: 5
for sample in mnist:
    if sample[1] == MINORITY_DIGIT:
        minority.append(sample)
    else:
        majority.append(sample)

In [9]:
# Seed every RNG the pipeline might draw from. torch drives random_split and
# weight init below; the style sampling inside CMNISTGenerator is not visible
# here and may use Python's `random` or NumPy, so seed those as well.
import random

import numpy as np

SEED = 101
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

<torch._C.Generator at 0x20867774950>

In [37]:
from src.utils import CMNISTGenerator, CMNIST

# Rich styles for the non-5 digits; presumably sampling weights (they sum to 1).
majority_style_weights = {
    corruptions.identity: 0.3,
    corruptions.stripe: 0.3,
    corruptions.zigzag: 0.2,
    corruptions.canny_edges: 0.2,
}
cmnist_generator = CMNISTGenerator(majority, majority_style_weights)
# NOTE(review): ToTensor() already rescales PIL / uint8 input to [0, 1]; if the
# generator yields such images, the extra /255 shrinks pixels to [0, 1/255].
cmnist = CMNIST(
    cmnist_generator,
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0]),
)


Generating dataset: 100%|██████████| 54579/54579 [00:07<00:00, 7405.64item/s]


In [38]:
# 2500 random 5s go to training; the remainder becomes the zero-shot test pool.
n_train_5 = 2500
train_5, test_5 = random_split(minority, lengths=[n_train_5, len(minority) - n_train_5])

In [39]:
# Styles for the training 5s. canny_edges is not among them; the test 5s below
# are rendered with canny_edges only.
five_train_styles = {
    corruptions.stripe: 0.3,
    corruptions.zigzag: 0.3,
    corruptions.identity: 0.4,
}
train_5_generator = CMNISTGenerator(train_5, five_train_styles)
# NOTE(review): this rebinds `train_5` from the raw split to a CMNIST dataset,
# so re-running this cell on its own would feed CMNIST into CMNISTGenerator.
train_5 = CMNIST(
    train_5_generator,
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0]),
)

Generating dataset: 100%|██████████| 2500/2500 [00:00<00:00, 8865.72item/s]


In [40]:
# Pool the styled majority digits with the styled training 5s, then hold out a
# random (1 - TRAIN_FRACTION) share for validation. The pool gets its own name
# so re-running this cell re-splits the full pool rather than splitting the
# previous training subset again.
TRAIN_FRACTION = 0.8
train_pool = ConcatDataset([cmnist, train_5])
train_size = int(TRAIN_FRACTION * len(train_pool))
train, valid = random_split(train_pool, [train_size, len(train_pool) - train_size])

In [41]:
# Every held-out 5 is rendered with canny edges only.
test_5_styles = {corruptions.canny_edges: 1.0}
test_5_generator = CMNISTGenerator(test_5, test_5_styles)
# NOTE(review): like the training 5s, this rebinds `test_5` (raw split -> CMNIST),
# so the cell is not safe to re-run on its own.
test_5 = CMNIST(
    test_5_generator,
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0]),
)

Generating dataset: 100%|██████████| 2921/2921 [00:00<00:00, 4457.96item/s]


In [42]:
# Only the training data is shuffled; evaluation order does not affect accuracy.
BATCH_SIZE = 128
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
valid_loader, test_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False) for ds in (valid, test_5)
)

### Baseline: supervised CNN classifier

In [43]:
from src.model import SimpleCNNClassifier
from src.trainer import SimpleCNNTrainer

# Supervised baseline: a 10-class CNN trained with cross-entropy.
criterion = torch.nn.CrossEntropyLoss()
cnn = SimpleCNNClassifier(n_class=10).to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-4)
# The positional 5 is passed straight to SimpleCNNTrainer (presumably a
# logging period; its meaning is defined in src.trainer).
trainer = SimpleCNNTrainer(cnn, optimizer, criterion, 5, device)

In [44]:
CNN_EPOCHS = 26  # epochs 0..25
trainer.fit(CNN_EPOCHS, train_loader, valid_loader)

epoch 0: 100%|██████████| 357/357 [00:01<00:00, 211.04batch/s, acc=0.895, loss=0.426]
val-epoch 0: 100%|██████████| 90/90 [00:00<00:00, 536.91it/s]


val_acc=0.921


epoch 5: 100%|██████████| 357/357 [00:01<00:00, 234.72batch/s, acc=1, loss=0.031]     
val-epoch 5: 100%|██████████| 90/90 [00:00<00:00, 532.30it/s]


val_acc=0.973


epoch 10: 100%|██████████| 357/357 [00:01<00:00, 235.05batch/s, acc=1, loss=0.0146]    
val-epoch 10: 100%|██████████| 90/90 [00:00<00:00, 504.03it/s]


val_acc=0.977


epoch 15: 100%|██████████| 357/357 [00:01<00:00, 237.26batch/s, acc=1, loss=0.00807]   
val-epoch 15: 100%|██████████| 90/90 [00:00<00:00, 523.25it/s]


val_acc=0.978


epoch 20: 100%|██████████| 357/357 [00:01<00:00, 238.99batch/s, acc=1, loss=0.00594]   
val-epoch 20: 100%|██████████| 90/90 [00:00<00:00, 516.33it/s]


val_acc=0.978


epoch 25: 100%|██████████| 357/357 [00:01<00:00, 236.11batch/s, acc=1, loss=0.00257] 
val-epoch 25: 100%|██████████| 90/90 [00:00<00:00, 523.25it/s]

val_acc=0.979





In [45]:
# Evaluate the CNN baseline on the held-out canny-edge 5s (test_loader); the
# "val_acc" this prints is therefore the test accuracy.
# NOTE(review): `_valid` is a private SimpleCNNTrainer method, and the meaning
# of the positional `True` is defined in src.trainer; confirm before reusing.
trainer._valid(test_loader, True, epoch_id=0)

val-epoch 0: 100%|██████████| 23/23 [00:00<00:00, 273.87it/s]

val_acc=0.829





### CD-VAE zero-shot: frozen contrastive VAE encoder + MLP probe

In [46]:
# Contrastive (SimCLR-style) training of a VAE with a 32-dim total latent.
vae = VAE(total_z_dim=32).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=3e-4)
simclr_hparams = {"temperature": 0.1, "beta": 100}
trainer = SimCLRTrainer(
    vae,
    optimizer,
    device=device,
    sim_fn="cosine",
    hyperparameter=simclr_hparams,
    verbose_period=20,
)

In [47]:
VAE_EPOCHS = 101  # epochs 0..100
trainer.fit(VAE_EPOCHS, train_loader)

Epoch 0: 100%|██████████| 357/357 [00:03<00:00, 96.27batch/s, c_loss=0.221, vae_loss=76.9] 
Epoch 20: 100%|██████████| 357/357 [00:03<00:00, 108.28batch/s, c_loss=0.0117, vae_loss=23.1] 
Epoch 40: 100%|██████████| 357/357 [00:03<00:00, 110.34batch/s, c_loss=0.00317, vae_loss=17.5]
Epoch 60: 100%|██████████| 357/357 [00:03<00:00, 106.60batch/s, c_loss=0.00443, vae_loss=16]  
Epoch 80: 100%|██████████| 357/357 [00:03<00:00, 108.43batch/s, c_loss=0.00409, vae_loss=14.7]
Epoch 100: 100%|██████████| 357/357 [00:03<00:00, 114.89batch/s, c_loss=0.00292, vae_loss=14.8]


In [48]:
# Freeze the trained VAE: switch to eval mode and turn off gradients for every
# parameter, so only the probe below is trained. `;` hides the module repr.
vae.eval().requires_grad_(False);

In [49]:
# Two-layer MLP probe mapping the VAE's mu_c features to 10 digit logits.
# NOTE(review): 16 input features presumably equals the content half of
# VAE(total_z_dim=32); confirm against src.model.VAE.encode.
PROBE_IN, PROBE_HIDDEN, N_CLASSES = 16, 256, 10
model = torch.nn.Sequential(
    torch.nn.Linear(PROBE_IN, PROBE_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(PROBE_HIDDEN, N_CLASSES),
).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [50]:
EPOCHS = 26
VERBOSE_PERIOD = 5

# Train the MLP probe on the frozen VAE's first encoder output (mu_c); only
# `model`'s parameters are optimised. Every VERBOSE_PERIOD epochs, show progress
# bars and report validation accuracy.
for epoch in range(EPOCHS):
    verbose = (epoch % VERBOSE_PERIOD) == 0
    model.train()
    with tqdm(train_loader, unit="batch", disable=not verbose) as bar:
        bar.set_description(f"epoch {epoch}")
        for X_batch, y_batch in bar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            # The VAE is frozen, so encode without building an autograd graph.
            with torch.no_grad():
                mu_c = vae.encode(X_batch)[0]
            optimizer.zero_grad()
            logits = model(mu_c)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()

            # progress-bar stats for the current batch
            acc = accurary(logits, y_batch)
            bar.set_postfix(loss=float(loss), acc=float(acc))

    # Validation accuracy weighted by batch size, so the short final batch is
    # not over-weighted. Assumes `accurary` returns the batch mean (the printed
    # values lie in [0, 1]).
    model.eval()
    n_correct, n_seen = 0.0, 0
    with torch.no_grad():
        for X_batch, y_batch in tqdm(valid_loader, disable=not verbose):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            logits = model(vae.encode(X_batch)[0])
            n_correct += accurary(logits, y_batch).item() * y_batch.size(0)
            n_seen += y_batch.size(0)
    if verbose:
        print("val_acc={:.3f}".format(n_correct / max(n_seen, 1)))
    
            


epoch 0: 100%|██████████| 357/357 [00:01<00:00, 191.10batch/s, acc=1, loss=0.0652]   
100%|██████████| 90/90 [00:00<00:00, 465.65it/s]


val_acc=0.983


epoch 5: 100%|██████████| 357/357 [00:01<00:00, 258.40batch/s, acc=1, loss=0.00363] 
100%|██████████| 90/90 [00:00<00:00, 463.95it/s]


val_acc=0.983


epoch 10: 100%|██████████| 357/357 [00:01<00:00, 260.76batch/s, acc=1, loss=0.000589]
100%|██████████| 90/90 [00:00<00:00, 439.02it/s]


val_acc=0.984


epoch 15: 100%|██████████| 357/357 [00:01<00:00, 252.56batch/s, acc=1, loss=0.000123]
100%|██████████| 90/90 [00:00<00:00, 452.25it/s]


val_acc=0.984


epoch 20: 100%|██████████| 357/357 [00:01<00:00, 263.21batch/s, acc=1, loss=0.000315]
100%|██████████| 90/90 [00:00<00:00, 432.53it/s]


val_acc=0.984


epoch 25: 100%|██████████| 357/357 [00:01<00:00, 256.20batch/s, acc=1, loss=1.78e-5] 
100%|██████████| 90/90 [00:00<00:00, 436.89it/s]

val_acc=0.984





In [51]:
# Zero-shot test of the probe: held-out 5s rendered only with canny edges (the
# training 5s use stripe/zigzag/identity). Always reports, with no dependency on
# the training loop's `verbose`, and weights batches by size.
# NOTE(review): the CNN baseline figure comes from trainer._valid, whose
# averaging is not visible here; confirm the two numbers are comparable.
model.eval()
n_correct, n_seen = 0.0, 0
with torch.no_grad():
    for X_batch, y_batch in tqdm(test_loader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        mu_c = vae.encode(X_batch)[0]
        logits = model(mu_c)
        n_correct += accurary(logits, y_batch).item() * y_batch.size(0)
        n_seen += y_batch.size(0)
print("test_acc={:.3f}".format(n_correct / max(n_seen, 1)))

100%|██████████| 23/23 [00:00<00:00, 323.94it/s]

val_acc=0.676



