In [21]:
from src.model import VAE
from src.trainer import SimCLRTrainer
from src.losses import accurary
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, ConcatDataset


from tqdm import tqdm

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
from corruption_utils import corruptions

In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [24]:
mnist = torchvision.datasets.MNIST("../data", train=True, download=True)

In [25]:
majority = list(filter(lambda elem: elem[1] != 5, mnist)) # not 5
minority = list(filter(lambda elem: elem[1] == 5, mnist)) # 5

In [26]:
from src.utils import CMNISTGenerator, CMNIST

cmnist_generator = CMNISTGenerator(
    majority, # assign rich styles to non-5 digits
    {
        corruptions.identity: 0.3,
        corruptions.stripe: 0.3,
        corruptions.zigzag: 0.2,
        corruptions.canny_edges: 0.2,
    },
)
cmnist = CMNIST(
    cmnist_generator, 
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0])
)


Generating dataset: 100%|██████████| 54579/54579 [00:07<00:00, 7023.53item/s]


In [27]:
train_5, test_5 = random_split(minority, [2500, len(minority) - 2500])

In [28]:
train_5_generator = CMNISTGenerator(
    train_5,
    {
        corruptions.identity: 1.0
    },
)
train_5 = CMNIST(
    train_5_generator, 
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0])
)

Generating dataset: 100%|██████████| 2500/2500 [00:00<00:00, 57803.99item/s]


In [29]:
train = ConcatDataset([cmnist, train_5])
train, valid = random_split(train, [46000, len(train) - 46000])

In [30]:
test_5_generator = CMNISTGenerator(
    test_5,
    {
        # corruptions.stripe: 0.3,
        # corruptions.zigzag: 0.3,
        # corruptions.canny_edges: 0.4,
        corruptions.dotted_line: 0.5,
        corruptions.brightness: 0.5
    },
)
test_5 = CMNIST(
    test_5_generator, 
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0])
)

Generating dataset: 100%|██████████| 2921/2921 [00:00<00:00, 4640.15item/s]


In [31]:
train_loader = DataLoader(train, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid, batch_size=128, shuffle=False)
test_loader = DataLoader(test_5, batch_size=128, shuffle=False)

### CNN classifier

In [32]:
from src.model import SimpleCNNClassifier
from src.trainer import SimpleCNNTrainer

cnn = SimpleCNNClassifier(n_class=10).to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()
trainer = SimpleCNNTrainer(cnn, optimizer, criterion, 5, device)

In [33]:
trainer.fit(31, train_loader, valid_loader)

epoch 0: 100%|██████████| 360/360 [00:02<00:00, 171.30batch/s, acc=0.854, loss=0.334]
val-epoch 0: 100%|██████████| 87/87 [00:00<00:00, 358.13it/s]


val_acc=0.917


epoch 5: 100%|██████████| 360/360 [00:01<00:00, 211.75batch/s, acc=0.958, loss=0.143] 
val-epoch 5: 100%|██████████| 87/87 [00:00<00:00, 396.75it/s]


val_acc=0.974


epoch 10: 100%|██████████| 360/360 [00:01<00:00, 211.99batch/s, acc=1, loss=0.0183]    
val-epoch 10: 100%|██████████| 87/87 [00:00<00:00, 369.37it/s]


val_acc=0.978


epoch 15: 100%|██████████| 360/360 [00:01<00:00, 209.37batch/s, acc=1, loss=0.00773]   
val-epoch 15: 100%|██████████| 87/87 [00:00<00:00, 373.61it/s]


val_acc=0.980


epoch 20: 100%|██████████| 360/360 [00:01<00:00, 210.22batch/s, acc=1, loss=0.0061]    
val-epoch 20: 100%|██████████| 87/87 [00:00<00:00, 393.30it/s]


val_acc=0.979


epoch 25: 100%|██████████| 360/360 [00:01<00:00, 213.34batch/s, acc=1, loss=0.000469]
val-epoch 25: 100%|██████████| 87/87 [00:00<00:00, 392.74it/s]


val_acc=0.980


epoch 30: 100%|██████████| 360/360 [00:01<00:00, 233.21batch/s, acc=1, loss=0.000811]
val-epoch 30: 100%|██████████| 87/87 [00:00<00:00, 500.00it/s]

val_acc=0.980





In [34]:
trainer._valid(test_loader, True, epoch_id=0)

val-epoch 0: 100%|██████████| 23/23 [00:00<00:00, 522.73it/s]

val_acc=0.438





### CD-VAE zero-shot

In [35]:
vae = VAE(total_z_dim=32).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=3e-4)
trainer = SimCLRTrainer(
    vae,
    optimizer,
    sim_fn="cosine",
    hyperparameter={"temperature": 0.1, "beta": 100},
    verbose_period=5,
    device=device,
)

In [36]:
trainer.fit(31, train_loader)

Epoch 0: 100%|██████████| 360/360 [00:03<00:00, 103.88batch/s, c_loss=0.106, vae_loss=79]  
Epoch 5: 100%|██████████| 360/360 [00:03<00:00, 107.98batch/s, c_loss=0.0413, vae_loss=41.8]
Epoch 10: 100%|██████████| 360/360 [00:03<00:00, 105.84batch/s, c_loss=0.063, vae_loss=33]    
Epoch 15: 100%|██████████| 360/360 [00:03<00:00, 108.86batch/s, c_loss=0.0126, vae_loss=25.4] 
Epoch 20: 100%|██████████| 360/360 [00:03<00:00, 104.42batch/s, c_loss=0.0267, vae_loss=21.4] 
Epoch 25: 100%|██████████| 360/360 [00:03<00:00, 108.92batch/s, c_loss=0.0355, vae_loss=20.1] 
Epoch 30: 100%|██████████| 360/360 [00:03<00:00, 107.21batch/s, c_loss=0.00682, vae_loss=21.8]


In [37]:
vae.eval()
for p in vae.parameters():
    p.requires_grad = False

In [38]:
model = torch.nn.Sequential(
    torch.nn.Linear(16, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10)
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

In [39]:
EPOCHS = 31
VERBOSE_PERIOD = 5

for epoch in range(EPOCHS):
    verbose = (epoch % VERBOSE_PERIOD) == 0
    model.train()
    with tqdm(train_loader, unit="batch", disable=not verbose) as bar:
        bar.set_description(f"epoch {epoch}")
        for X_batch, y_batch in bar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            mu_c, logvar_c = vae.encode(X_batch)[:2]
            logits = model(mu_c)
            # logits = model(torch.cat([mu_c, logvar_c], dim=1))
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()

            # update running stats
            acc = accurary(logits, y_batch)
            bar.set_postfix(loss=float(loss), acc=float(acc))
    model.eval()
    total_acc = 0
    with torch.no_grad():
        for X_batch, y_batch in tqdm(valid_loader, disable=not verbose):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            mu_c, logvar_c = vae.encode(X_batch)[:2]
            logits = model(mu_c)
            # logits = model(torch.cat([mu_c, logvar_c], dim=1))
            acc = accurary(logits, y_batch)
            total_acc += acc.item()
    if verbose:
        print("val_acc={:.3f}".format(total_acc / len(valid_loader)))
    
            


epoch 0: 100%|██████████| 360/360 [00:01<00:00, 255.85batch/s, acc=1, loss=0.059]    
100%|██████████| 87/87 [00:00<00:00, 411.94it/s]


val_acc=0.984


epoch 5: 100%|██████████| 360/360 [00:01<00:00, 250.19batch/s, acc=1, loss=0.00329]   
100%|██████████| 87/87 [00:00<00:00, 328.48it/s]


val_acc=0.985


epoch 10: 100%|██████████| 360/360 [00:01<00:00, 247.13batch/s, acc=1, loss=0.00316]   
100%|██████████| 87/87 [00:00<00:00, 337.47it/s]


val_acc=0.985


epoch 15: 100%|██████████| 360/360 [00:01<00:00, 257.19batch/s, acc=1, loss=4.5e-5]   
100%|██████████| 87/87 [00:00<00:00, 330.83it/s]


val_acc=0.985


epoch 20: 100%|██████████| 360/360 [00:01<00:00, 251.64batch/s, acc=1, loss=5.25e-5]   
100%|██████████| 87/87 [00:00<00:00, 339.75it/s]


val_acc=0.985


epoch 25: 100%|██████████| 360/360 [00:01<00:00, 260.61batch/s, acc=1, loss=1.29e-6]    
100%|██████████| 87/87 [00:00<00:00, 346.71it/s]


val_acc=0.985


epoch 30: 100%|██████████| 360/360 [00:01<00:00, 238.69batch/s, acc=1, loss=7.77e-7] 
100%|██████████| 87/87 [00:00<00:00, 328.27it/s]

val_acc=0.985





In [40]:
model.eval()
total_acc = 0
with torch.no_grad():
    for X_batch, y_batch in tqdm(test_loader, disable=not verbose):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        mu_c, logvar_c = vae.encode(X_batch)[:2]
        logits = model(mu_c)
        # logits = model(torch.cat([mu_c, logvar_c], dim=1))
        acc = accurary(logits, y_batch)
        total_acc += acc.item()
if verbose:
    print("val_acc={:.3f}".format(total_acc / len(test_loader)))

100%|██████████| 23/23 [00:00<00:00, 306.55it/s]

val_acc=0.480



