In [1]:
from src.model import VAE
from src.trainer import SimCLRTrainer
from src.losses import accurary
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, ConcatDataset


from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
from corruption_utils import corruptions

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
mnist = torchvision.datasets.MNIST("../data", train=True, download=True)

In [5]:
majority = list(filter(lambda elem: elem[1] != 5, mnist)) # not 5
minority = list(filter(lambda elem: elem[1] == 5, mnist)) # 5

In [6]:
from src.utils import CMNISTGenerator, CMNIST

cmnist_generator = CMNISTGenerator(
    majority, # assign rich styles to non-5 digits
    {
        corruptions.identity: 0.3,
        corruptions.stripe: 0.3,
        corruptions.zigzag: 0.2,
        corruptions.canny_edges: 0.2,
    },
)
cmnist = CMNIST(
    cmnist_generator, 
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0])
)


Generating dataset: 100%|██████████| 54579/54579 [00:07<00:00, 7237.91item/s]


In [7]:
train_5, test_5 = random_split(minority, [2500, len(minority) - 2500])

In [8]:
train_5_generator = CMNISTGenerator(
    train_5,
    {
        corruptions.identity: 1.0
    },
)
train_5 = CMNIST(
    train_5_generator, 
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0])
)

Generating dataset: 100%|██████████| 2500/2500 [00:00<00:00, 62471.39item/s]


In [9]:
train = ConcatDataset([cmnist, train_5])
train, valid = random_split(train, [46000, len(train) - 46000])

In [10]:
test_5_generator = CMNISTGenerator(
    test_5,
    {
        corruptions.stripe: 0.3,
        corruptions.zigzag: 0.3,
        corruptions.canny_edges: 0.4,
        # corruptions.dotted_line: 0.3,
        # corruptions.translate: 0.3,
        # corruptions.shear: 0.4
    },
)
test_5 = CMNIST(
    test_5_generator, 
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0])
)

Generating dataset: 100%|██████████| 2921/2921 [00:00<00:00, 4629.46item/s]


In [12]:
train_loader = DataLoader(train, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid, batch_size=128, shuffle=False)
test_loader = DataLoader(test_5, batch_size=128, shuffle=False)

### CNN classifier

In [30]:
from src.model import SimpleCNNClassifier
from src.trainer import SimpleCNNTrainer

cnn = SimpleCNNClassifier(n_class=10).to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()
trainer = SimpleCNNTrainer(cnn, optimizer, criterion, 2, device)

In [31]:
trainer.fit(21, train_loader, valid_loader)

epoch 0: 100%|██████████| 360/360 [00:02<00:00, 162.56batch/s, acc=0.979, loss=0.242]
val-epoch 0: 100%|██████████| 87/87 [00:00<00:00, 396.38it/s]


val_acc=0.920


epoch 2: 100%|██████████| 360/360 [00:01<00:00, 243.48batch/s, acc=1, loss=0.0396]    
val-epoch 2: 100%|██████████| 87/87 [00:00<00:00, 530.48it/s]


val_acc=0.964


epoch 4: 100%|██████████| 360/360 [00:01<00:00, 239.81batch/s, acc=0.958, loss=0.142] 
val-epoch 4: 100%|██████████| 87/87 [00:00<00:00, 508.45it/s]


val_acc=0.973


epoch 6: 100%|██████████| 360/360 [00:01<00:00, 237.47batch/s, acc=1, loss=0.0409]    
val-epoch 6: 100%|██████████| 87/87 [00:00<00:00, 514.43it/s]


val_acc=0.976


epoch 8: 100%|██████████| 360/360 [00:01<00:00, 239.88batch/s, acc=1, loss=0.03]      
val-epoch 8: 100%|██████████| 87/87 [00:00<00:00, 524.09it/s]


val_acc=0.977


epoch 10: 100%|██████████| 360/360 [00:01<00:00, 238.45batch/s, acc=0.979, loss=0.0292]
val-epoch 10: 100%|██████████| 87/87 [00:00<00:00, 521.69it/s]


val_acc=0.980


epoch 12: 100%|██████████| 360/360 [00:01<00:00, 238.31batch/s, acc=0.979, loss=0.0824]
val-epoch 12: 100%|██████████| 87/87 [00:00<00:00, 504.30it/s]


val_acc=0.980


epoch 14: 100%|██████████| 360/360 [00:01<00:00, 229.07batch/s, acc=1, loss=0.0204]    
val-epoch 14: 100%|██████████| 87/87 [00:00<00:00, 388.40it/s]


val_acc=0.981


epoch 16: 100%|██████████| 360/360 [00:01<00:00, 233.64batch/s, acc=1, loss=0.013]     
val-epoch 16: 100%|██████████| 87/87 [00:00<00:00, 511.70it/s]


val_acc=0.980


epoch 18: 100%|██████████| 360/360 [00:01<00:00, 238.13batch/s, acc=1, loss=0.00437]    
val-epoch 18: 100%|██████████| 87/87 [00:00<00:00, 545.28it/s]


val_acc=0.981


epoch 20: 100%|██████████| 360/360 [00:01<00:00, 220.53batch/s, acc=1, loss=0.00308]   
val-epoch 20: 100%|██████████| 87/87 [00:00<00:00, 527.04it/s]

val_acc=0.981





In [32]:
trainer._valid(test_loader, True, epoch_id=0)

val-epoch 0: 100%|██████████| 23/23 [00:00<00:00, 287.46it/s]

val_acc=0.610





### CD-VAE zero-shot

In [16]:
vae = VAE(total_z_dim=32).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=3e-4)
trainer = SimCLRTrainer(
    vae,
    optimizer,
    sim_fn="cosine",
    hyperparameter={"temperature": 0.1, "beta": 100},
    verbose_period=5,
    device=device,
)

In [17]:
trainer.fit(31, train_loader)

Epoch 0: 100%|██████████| 360/360 [00:04<00:00, 82.54batch/s, c_loss=0.208, vae_loss=88.2]  
Epoch 5: 100%|██████████| 360/360 [00:03<00:00, 108.06batch/s, c_loss=0.021, vae_loss=50.2] 
Epoch 10: 100%|██████████| 360/360 [00:03<00:00, 101.63batch/s, c_loss=0.0191, vae_loss=29.3] 
Epoch 15: 100%|██████████| 360/360 [00:03<00:00, 110.22batch/s, c_loss=0.0113, vae_loss=24]   
Epoch 20: 100%|██████████| 360/360 [00:03<00:00, 110.40batch/s, c_loss=0.0178, vae_loss=22.5] 
Epoch 25: 100%|██████████| 360/360 [00:03<00:00, 107.47batch/s, c_loss=0.0274, vae_loss=18.8] 
Epoch 30: 100%|██████████| 360/360 [00:03<00:00, 97.23batch/s, c_loss=0.00734, vae_loss=19.6] 


In [18]:
vae.eval()
for p in vae.parameters():
    p.requires_grad = False

In [26]:
model = torch.nn.Sequential(
    torch.nn.Linear(16, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10)
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

In [27]:
EPOCHS = 21
VERBOSE_PERIOD = 2

for epoch in range(EPOCHS):
    verbose = (epoch % VERBOSE_PERIOD) == 0
    model.train()
    with tqdm(train_loader, unit="batch", disable=not verbose) as bar:
        bar.set_description(f"epoch {epoch}")
        for X_batch, y_batch in bar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            mu_c, logvar_c = vae.encode(X_batch)[:2]
            logits = model(mu_c)
            # logits = model(torch.cat([mu_c, logvar_c], dim=1))
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()

            # update running stats
            acc = accurary(logits, y_batch)
            bar.set_postfix(loss=float(loss), acc=float(acc))
    model.eval()
    total_acc = 0
    with torch.no_grad():
        for X_batch, y_batch in tqdm(valid_loader, disable=not verbose):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            mu_c, logvar_c = vae.encode(X_batch)[:2]
            logits = model(mu_c)
            # logits = model(torch.cat([mu_c, logvar_c], dim=1))
            acc = accurary(logits, y_batch)
            total_acc += acc.item()
    if verbose:
        print("val_acc={:.3f}".format(total_acc / len(valid_loader)))
    
            


epoch 0: 100%|██████████| 360/360 [00:02<00:00, 177.23batch/s, acc=1, loss=0.0584]    
100%|██████████| 87/87 [00:00<00:00, 303.08it/s]


val_acc=0.987


epoch 2: 100%|██████████| 360/360 [00:01<00:00, 248.89batch/s, acc=1, loss=0.0125]   
100%|██████████| 87/87 [00:00<00:00, 327.04it/s]


val_acc=0.987


epoch 4: 100%|██████████| 360/360 [00:01<00:00, 242.04batch/s, acc=1, loss=0.00191]   
100%|██████████| 87/87 [00:00<00:00, 332.28it/s]


val_acc=0.988


epoch 6: 100%|██████████| 360/360 [00:01<00:00, 260.59batch/s, acc=1, loss=0.000627]   
100%|██████████| 87/87 [00:00<00:00, 343.43it/s]


val_acc=0.988


epoch 8: 100%|██████████| 360/360 [00:01<00:00, 245.42batch/s, acc=1, loss=0.00601]    
100%|██████████| 87/87 [00:00<00:00, 335.45it/s]


val_acc=0.988


epoch 10: 100%|██████████| 360/360 [00:01<00:00, 253.98batch/s, acc=1, loss=0.00108]    
100%|██████████| 87/87 [00:00<00:00, 321.03it/s]


val_acc=0.987


epoch 12: 100%|██████████| 360/360 [00:01<00:00, 255.73batch/s, acc=1, loss=0.000111]
100%|██████████| 87/87 [00:00<00:00, 322.50it/s]


val_acc=0.987


epoch 14: 100%|██████████| 360/360 [00:01<00:00, 258.55batch/s, acc=1, loss=4.23e-5] 
100%|██████████| 87/87 [00:00<00:00, 337.63it/s]


val_acc=0.987


epoch 16: 100%|██████████| 360/360 [00:01<00:00, 260.03batch/s, acc=1, loss=4.24e-5] 
100%|██████████| 87/87 [00:00<00:00, 357.70it/s]


val_acc=0.987


epoch 18: 100%|██████████| 360/360 [00:01<00:00, 220.85batch/s, acc=1, loss=1.64e-5] 
100%|██████████| 87/87 [00:00<00:00, 325.23it/s]


val_acc=0.987


epoch 20: 100%|██████████| 360/360 [00:01<00:00, 252.30batch/s, acc=1, loss=0.000213]
100%|██████████| 87/87 [00:00<00:00, 334.04it/s]

val_acc=0.987





In [29]:
model.eval()
total_acc = 0
with torch.no_grad():
    for X_batch, y_batch in tqdm(test_loader, disable=not verbose):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        mu_c, logvar_c = vae.encode(X_batch)[:2]
        logits = model(mu_c)
        # logits = model(torch.cat([mu_c, logvar_c], dim=1))
        acc = accurary(logits, y_batch)
        total_acc += acc.item()
if verbose:
    print("val_acc={:.3f}".format(total_acc / len(test_loader)))

100%|██████████| 23/23 [00:00<00:00, 255.55it/s]

val_acc=0.714



