In [1]:
from src.model import VAE
from src.trainer import SimCLRTrainer
import torch
import torchvision
import torchvision.transforms as transforms


from torch.utils.data import DataLoader, random_split, ConcatDataset
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
from corruption_utils import corruptions

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
mnist = torchvision.datasets.MNIST("../data", train=True, download=True)

In [5]:
majority = list(filter(lambda elem: elem[1] != 5, mnist)) # not 5
minority = list(filter(lambda elem: elem[1] == 5, mnist)) # 5

In [6]:
from src.utils import CMNISTGenerator, CMNIST

cmnist_generator = CMNISTGenerator(
    majority, # assign rich styles to non-5 digits
    {
        corruptions.identity: 0.2,
        corruptions.stripe: 0.3,
        corruptions.zigzag: 0.3,
        corruptions.canny_edges: 0.2,
    },
)
cmnist = CMNIST(
    cmnist_generator, 
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0])
)


Generating dataset: 100%|██████████| 54579/54579 [00:08<00:00, 6165.67item/s]


In [7]:
train_5, test_5 = random_split(minority, [2500, len(minority) - 2500])

In [8]:
train_5_generator = CMNISTGenerator(
    train_5,
    {
        corruptions.identity: 1.0
    },
)
train_5 = CMNIST(
    train_5_generator, 
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0])
)

Generating dataset: 100%|██████████| 2500/2500 [00:00<00:00, 54334.30item/s]


In [9]:
train = ConcatDataset([cmnist, train_5])

In [10]:
test_5_generator = CMNISTGenerator(
    test_5,
    {
        corruptions.stripe: 0.3,
        corruptions.zigzag: 0.3,
        corruptions.canny_edges: 0.4,
    },
)
test_5 = CMNIST(
    test_5_generator, 
    transforms.Compose([transforms.ToTensor(), lambda img: img / 255.0])
)

Generating dataset: 100%|██████████| 2921/2921 [00:00<00:00, 4435.50item/s]


In [11]:
train_loader = DataLoader(train, batch_size=128, shuffle=True)
test_loader = DataLoader(test_5, batch_size=128, shuffle=False)

In [12]:
def accurary(logit: torch.Tensor, y: torch.Tensor):
    yh = logit.argmax(dim=1)
    return (yh == y).float().mean()

In [83]:
from src.model import SimpleCNNClassifier
from src.trainer import SimpleCNNTrainer

cnn = SimpleCNNClassifier(n_class=10).to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()
trainer = SimpleCNNTrainer(cnn, optimizer, criterion, 2, device)

In [84]:
trainer.fit(train_loader, 5)

epoch 0: 100%|██████████| 446/446 [00:02<00:00, 169.75batch/s, acc=0.966, loss=0.103] 
epoch 2: 100%|██████████| 446/446 [00:02<00:00, 203.51batch/s, acc=0.983, loss=0.0402]
epoch 4: 100%|██████████| 446/446 [00:02<00:00, 202.38batch/s, acc=1, loss=0.015]     


In [64]:
with torch.no_grad():
    cnn.eval()
    total_acc = 0
    for X_batch, y_batch in tqdm(test_loader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        logits = cnn(X_batch)
        acc = accurary(logits, y_batch)
        total_acc += acc.item()
print(total_acc / len(test_loader))

100%|██████████| 23/23 [00:00<00:00, 256.82it/s]

0.6973893642425537





In [74]:
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

In [75]:
for epoch in range(5):
    with tqdm(train_loader, unit="batch") as bar:
        bar.set_description(f"epoch {epoch}")
        for X_batch, y_batch in bar:
            X_batch, y_batch = X_batch.view(-1, 784).to(device), y_batch.to(device)
            optimizer.zero_grad()
            logits = model(X_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()

            # update running stats
            acc = accurary(logits, y_batch)
            bar.set_postfix(loss=float(loss), acc=float(acc))

epoch 0: 100%|██████████| 446/446 [00:02<00:00, 205.41batch/s, acc=0.916, loss=0.364]
epoch 1: 100%|██████████| 446/446 [00:02<00:00, 205.82batch/s, acc=0.908, loss=0.365]
epoch 2: 100%|██████████| 446/446 [00:02<00:00, 204.99batch/s, acc=0.916, loss=0.308]
epoch 3: 100%|██████████| 446/446 [00:02<00:00, 213.09batch/s, acc=0.857, loss=0.372]
epoch 4: 100%|██████████| 446/446 [00:02<00:00, 209.88batch/s, acc=0.899, loss=0.271]


In [77]:
with torch.no_grad():
    total_acc = 0
    for X_batch, y_batch in tqdm(test_loader):
        X_batch, y_batch = X_batch.view(-1, 784).to(device), y_batch.to(device)
        logits = model(X_batch)
        loss = criterion(logits, y_batch)
        acc = accurary(logits, y_batch)
        total_acc += acc.item()
print(total_acc / len(test_loader))

100%|██████████| 23/23 [00:00<00:00, 383.25it/s]

0.28125323549560877





In [65]:
vae = VAE(z_dim=16).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=3e-4)
trainer = SimCLRTrainer(
    vae,
    optimizer,
    sim_fn="cosine",
    hyperparameter={"temperature": 0.1, "beta": 100},
    verbose_period=5,
    device=device,
)

In [66]:
trainer.fit(train_loader, 31)

Epoch 0: 100%|██████████| 446/446 [00:05<00:00, 86.29batch/s, c_loss=0.157, vae_loss=75.3]  
Epoch 5: 100%|██████████| 446/446 [00:04<00:00, 103.54batch/s, c_loss=0.0217, vae_loss=37.6]
Epoch 10: 100%|██████████| 446/446 [00:07<00:00, 58.09batch/s, c_loss=0.0254, vae_loss=25.8]  
Epoch 15: 100%|██████████| 446/446 [00:10<00:00, 43.43batch/s, c_loss=0.0131, vae_loss=22.3]  
Epoch 20: 100%|██████████| 446/446 [00:04<00:00, 110.90batch/s, c_loss=0.0116, vae_loss=21.7] 
Epoch 25: 100%|██████████| 446/446 [00:04<00:00, 108.34batch/s, c_loss=0.0113, vae_loss=16.4] 
Epoch 30: 100%|██████████| 446/446 [00:04<00:00, 105.16batch/s, c_loss=0.0029, vae_loss=17.5] 


In [78]:
model = torch.nn.Sequential(
    torch.nn.Linear(16, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

In [79]:
vae.eval()
for epoch in range(5):
    with tqdm(train_loader, unit="batch") as bar:
        bar.set_description(f"epoch {epoch}")
        for X_batch, y_batch in bar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            mu_c = vae.encode(X_batch)[0]
            logits = model(mu_c)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()

            # update running stats
            acc = accurary(logits, y_batch)
            bar.set_postfix(loss=float(loss), acc=float(acc))
            


epoch 0: 100%|██████████| 446/446 [00:02<00:00, 157.03batch/s, acc=1, loss=0.0974]   
epoch 1: 100%|██████████| 446/446 [00:02<00:00, 185.65batch/s, acc=1, loss=0.0183]    
epoch 2: 100%|██████████| 446/446 [00:02<00:00, 190.69batch/s, acc=1, loss=0.0201]   
epoch 3: 100%|██████████| 446/446 [00:02<00:00, 188.44batch/s, acc=1, loss=0.00615]   
epoch 4: 100%|██████████| 446/446 [00:02<00:00, 181.65batch/s, acc=1, loss=0.00535]


In [82]:
with torch.no_grad():
    total_acc = 0
    for X_batch, y_batch in tqdm(test_loader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        mu_c = vae.encode(X_batch)[0]
        logits = model(mu_c)
        acc = accurary(logits, y_batch)
        total_acc += acc.item()
print(total_acc / len(test_loader))

100%|██████████| 23/23 [00:00<00:00, 223.99it/s]

0.6503234998039578



