In [1]:
from src.model import VAE
from src.trainer import SimCLRTrainer
from src.losses import accurary
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, ConcatDataset


from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
from corruption_utils import corruptions

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [4]:
# Official MNIST training split; downloaded into ../data on first use.
# No transform is attached here, so samples come back in torchvision's
# default (image, label) form.
mnist = torchvision.datasets.MNIST(
    root="../data",
    train=True,
    download=True,
)

In [5]:
# Partition MNIST by label in a single pass: digit 5 is the minority class,
# every other digit goes to the majority pool.
# Iterating the dataset decodes each image, so walking it once (instead of
# once per filter) halves the work while keeping the original sample order
# inside each list.
majority, minority = [], []
for sample in mnist:
    # sample[1] is the label; sample[0] is the image.
    (minority if sample[1] == 5 else majority).append(sample)

In [41]:
# Seed every RNG the notebook may draw from, not just PyTorch's.
# NOTE(review): CMNISTGenerator's source of randomness is not visible from
# this notebook; if it samples corruptions via `random` or NumPy, seeding
# torch alone would leave the generated dataset irreproducible.
import random

import numpy as np

SEED = 101
random.seed(SEED)
np.random.seed(SEED)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(SEED)

<torch._C.Generator at 0x26b7e6148f0>

In [42]:
from src.utils import CMNISTGenerator, CMNIST

# Assign rich styles to the non-5 digits. Each corruption is paired with a
# weight; the four weights add up to 1.0 (presumably sampling probabilities
# -- confirm against CMNISTGenerator).
majority_styles = {
    corruptions.identity: 0.3,
    corruptions.stripe: 0.3,
    corruptions.zigzag: 0.2,
    corruptions.canny_edges: 0.2,
}
cmnist_generator = CMNISTGenerator(majority, majority_styles)

# NOTE(review): ToTensor() already rescales uint8 / PIL input to [0, 1]; if
# the generator emits such images the extra division by 255 shrinks them a
# second time. Verify the dtype CMNISTGenerator produces.
majority_transform = transforms.Compose(
    [transforms.ToTensor(), lambda img: img / 255.0]
)
cmnist = CMNIST(cmnist_generator, majority_transform)


Generating dataset: 100%|██████████| 54579/54579 [00:08<00:00, 6719.04item/s]


In [43]:
# Hold back all but 2500 of the digit-5 samples for testing.
# random_split draws from torch's global RNG, so the partition depends on
# the seed set earlier in the notebook.
n_train_5 = 2500
n_test_5 = len(minority) - n_train_5
train_5, test_5 = random_split(minority, [n_train_5, n_test_5])

In [44]:
# The training 5s get a single style: `identity` with weight 1.0, so none of
# stripe / zigzag / canny_edges is ever applied to a 5 in the training data.
train_5_styles = {corruptions.identity: 1.0}
train_5_generator = CMNISTGenerator(train_5, train_5_styles)

# NOTE(review): ToTensor() already rescales uint8 / PIL input to [0, 1]; the
# extra division by 255 may scale twice -- verify what the generator emits.
train_5_transform = transforms.Compose(
    [transforms.ToTensor(), lambda img: img / 255.0]
)

# NOTE: `train_5` is rebound from the raw split to the CMNIST dataset, so
# re-running this cell without re-running the split feeds the wrapped
# dataset back into the generator.
train_5 = CMNIST(train_5_generator, train_5_transform)

Generating dataset: 100%|██████████| 2500/2500 [00:00<00:00, 30601.51item/s]


In [45]:
# Pool the styled non-5 digits with the identity-only 5s, then carve out an
# 80/20 train/validation split.
# (Ablation: use `train_5` alone as the pool to train on the minority digit only.)
combined = ConcatDataset([cmnist, train_5])
train_size = int(0.8 * len(combined))
valid_size = len(combined) - train_size
train, valid = random_split(combined, [train_size, valid_size])

In [46]:
# The held-out 5s only receive the non-identity corruptions (stripe, zigzag,
# canny_edges), i.e. exactly the styles the training 5s were never given.
test_5_styles = {
    corruptions.stripe: 0.3,
    corruptions.zigzag: 0.3,
    corruptions.canny_edges: 0.4,
}
test_5_generator = CMNISTGenerator(test_5, test_5_styles)

# NOTE(review): ToTensor() already rescales uint8 / PIL input to [0, 1]; the
# extra division by 255 may scale twice -- verify what the generator emits.
test_5_transform = transforms.Compose(
    [transforms.ToTensor(), lambda img: img / 255.0]
)

# NOTE: `test_5` is rebound from the raw split to the CMNIST dataset, so this
# cell is not safe to re-run on its own.
test_5 = CMNIST(test_5_generator, test_5_transform)

Generating dataset: 100%|██████████| 2921/2921 [00:00<00:00, 3779.47item/s]


In [47]:
# One batch size for all three loaders; only the training loader shuffles,
# so validation and test batches come in a fixed order.
BATCH_SIZE = 128

train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_5, batch_size=BATCH_SIZE, shuffle=False)

### CNN classifier

In [48]:
from src.model import SimpleCNNClassifier
from src.trainer import SimpleCNNTrainer

# Supervised baseline: a 10-way CNN classifier trained with Adam and
# cross-entropy.
cnn = SimpleCNNClassifier(n_class=10)
cnn = cnn.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-4)

# The positional 5 is presumably the trainer's logging period (the run below
# reports every 5th epoch) -- confirm against SimpleCNNTrainer's signature.
trainer = SimpleCNNTrainer(cnn, optimizer, criterion, 5, device)

In [49]:
# Train the CNN baseline on train_loader and validate on valid_loader.
# The first argument is presumably the epoch count (the log below runs
# through epoch 25) -- confirm against SimpleCNNTrainer.fit.
trainer.fit(26, train_loader, valid_loader)

epoch 0: 100%|██████████| 357/357 [00:01<00:00, 210.78batch/s, acc=0.937, loss=0.307]
val-epoch 0: 100%|██████████| 90/90 [00:00<00:00, 483.82it/s]


val_acc=0.926


epoch 5: 100%|██████████| 357/357 [00:01<00:00, 229.54batch/s, acc=0.979, loss=0.0663]
val-epoch 5: 100%|██████████| 90/90 [00:00<00:00, 482.11it/s]


val_acc=0.978


epoch 10: 100%|██████████| 357/357 [00:01<00:00, 239.67batch/s, acc=0.989, loss=0.0477]
val-epoch 10: 100%|██████████| 90/90 [00:00<00:00, 529.46it/s]


val_acc=0.981


epoch 15: 100%|██████████| 357/357 [00:01<00:00, 235.05batch/s, acc=1, loss=0.00894]   
val-epoch 15: 100%|██████████| 90/90 [00:00<00:00, 531.29it/s]


val_acc=0.980


epoch 20: 100%|██████████| 357/357 [00:01<00:00, 226.94batch/s, acc=1, loss=0.00328] 
val-epoch 20: 100%|██████████| 90/90 [00:00<00:00, 504.22it/s]


val_acc=0.982


epoch 25: 100%|██████████| 357/357 [00:01<00:00, 226.99batch/s, acc=1, loss=0.00126]   
val-epoch 25: 100%|██████████| 90/90 [00:00<00:00, 498.08it/s]

val_acc=0.981





In [50]:
# Score the CNN baseline on the held-out 5s (test_loader).
# NOTE(review): this calls the trainer's private `_valid` helper, so the
# result is printed under the "val_acc" label even though it is test data.
# The positional True is presumably a verbosity flag -- confirm against
# SimpleCNNTrainer._valid.
trainer._valid(test_loader, True, epoch_id=0)

val-epoch 0: 100%|██████████| 23/23 [00:00<00:00, 511.08it/s]

val_acc=0.696





### CD-VAE zero-shot

In [51]:
# CD-VAE with a 32-dimensional latent in total, trained through the SimCLR
# trainer with cosine similarity.
# NOTE: `optimizer` and `trainer` are rebound here, so the CNN baseline's
# trainer is no longer reachable after this cell runs.
vae = VAE(total_z_dim=32)
vae = vae.to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=3e-4)

# Loss hyperparameters, passed to the trainer unchanged.
simclr_hparams = {"temperature": 0.1, "beta": 1.0, "alpha": [50, 50]}

trainer = SimCLRTrainer(
    vae,
    optimizer,
    sim_fn="cosine",
    hyperparameter=simclr_hparams,
    verbose_period=20,
    device=device,
)

In [52]:
# Train the VAE on train_loader only (no validation loader is passed).
# The first argument is presumably the epoch count; with verbose_period=20
# set on the trainer, only some epochs are expected to log -- confirm
# against SimCLRTrainer.fit.
trainer.fit(41, train_loader)

Epoch 0: 100%|██████████| 357/357 [00:04<00:00, 81.07batch/s, c_loss=0.219, s_loss=-2.72, vae_loss=80.3]


In [None]:
# Freeze the trained VAE: switch it to eval mode and turn off gradient
# tracking for every parameter, so only the downstream classifier is updated.
vae.eval()
for param in vae.parameters():
    param.requires_grad_(False)

In [None]:
# Small MLP probe trained on top of the frozen VAE encoder.
# Its input width (16) has to match the width of the first tensor returned
# by vae.encode, which is what the training loop feeds it (presumably the
# content half of the 32-d latent -- confirm against VAE.encode).
classifier_layers = [
    torch.nn.Linear(16, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
]
model = torch.nn.Sequential(*classifier_layers).to(device)

# `optimizer` and `criterion` are rebound for the probe from here on.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [None]:
EPOCHS = 26
VERBOSE_PERIOD = 5

# Train the probe (`model`) on the first latent returned by vae.encode and
# validate it after every epoch. Progress bars and the validation score are
# only shown on every VERBOSE_PERIOD-th epoch.
for epoch in range(EPOCHS):
    verbose = epoch % VERBOSE_PERIOD == 0

    # ---- training pass ----
    model.train()
    with tqdm(train_loader, unit="batch", disable=not verbose) as bar:
        bar.set_description(f"epoch {epoch}")
        for X_batch, y_batch, _ in bar:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            # Only the first two outputs of the encoder are used; the probe
            # sees mu_c alone (logvar_c is unpacked but not fed to it).
            # An alternative input would be torch.cat([mu_c, logvar_c], dim=1).
            mu_c, logvar_c = vae.encode(X_batch)[:2]
            logits = model(mu_c)
            loss = criterion(logits, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update running stats
            acc = accurary(logits, y_batch)
            bar.set_postfix(loss=loss.item(), acc=acc.item())

    # ---- validation pass ----
    model.eval()
    total_acc = 0
    with torch.no_grad():
        for X_batch, y_batch, _ in tqdm(valid_loader, disable=not verbose):
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            mu_c, logvar_c = vae.encode(X_batch)[:2]
            logits = model(mu_c)
            acc = accurary(logits, y_batch)
            total_acc += acc.item()

    # The reported score is the plain mean of the per-batch accuracies.
    if verbose:
        print("val_acc={:.3f}".format(total_acc / len(valid_loader)))
    
            


epoch 0: 100%|██████████| 357/357 [00:01<00:00, 246.67batch/s, acc=0.989, loss=0.16] 
100%|██████████| 90/90 [00:00<00:00, 448.36it/s]


val_acc=0.978


epoch 5: 100%|██████████| 357/357 [00:01<00:00, 255.41batch/s, acc=1, loss=0.0124]    
100%|██████████| 90/90 [00:00<00:00, 435.14it/s]


val_acc=0.981


epoch 10: 100%|██████████| 357/357 [00:01<00:00, 247.48batch/s, acc=1, loss=0.00291]    
100%|██████████| 90/90 [00:00<00:00, 366.06it/s]


val_acc=0.981


epoch 15: 100%|██████████| 357/357 [00:01<00:00, 240.38batch/s, acc=1, loss=0.009]      
100%|██████████| 90/90 [00:00<00:00, 388.93it/s]


val_acc=0.980


epoch 20: 100%|██████████| 357/357 [00:01<00:00, 250.85batch/s, acc=1, loss=0.00343]    
100%|██████████| 90/90 [00:00<00:00, 441.18it/s]


val_acc=0.981


epoch 25: 100%|██████████| 357/357 [00:01<00:00, 255.57batch/s, acc=0.989, loss=0.0364] 
100%|██████████| 90/90 [00:00<00:00, 414.56it/s]

val_acc=0.981





In [None]:
# Zero-shot evaluation of the probe on the held-out 5s (test_loader).
#
# The progress bar and the final print are unconditional. Gating them on
# `verbose` would reuse whatever value the training loop left behind on its
# last epoch, and would silently hide this result whenever
# (EPOCHS - 1) % VERBOSE_PERIOD != 0.
model.eval()
total_acc = 0
with torch.no_grad():
    for X_batch, y_batch, _ in tqdm(test_loader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        # Same feature path as in training: the first encoder output only.
        mu_c, logvar_c = vae.encode(X_batch)[:2]
        logits = model(mu_c)
        acc = accurary(logits, y_batch)
        total_acc += acc.item()

# Plain mean of the per-batch accuracies. The "val_acc" label is kept so the
# line matches the CNN baseline's report, although this is test data.
print("val_acc={:.3f}".format(total_acc / len(test_loader)))

100%|██████████| 23/23 [00:00<00:00, 386.20it/s]

val_acc=0.949



