In [1]:
import torch
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm
import wandb
from dotenv import dotenv_values
import os
import json

In [2]:
envs = ["secret.env"]

for fenv in envs:
    file = os.path.join("env", fenv)
    config = dotenv_values(file)  # load sensitive variables
    print(config.keys())
    for c, v in config.items():
        os.environ[c] = v

odict_keys(['WANDB_API_KEY', 'WANDB_PROJECT'])


In [3]:
wandb_key = os.environ["WANDB_API_KEY"]
wandb.login(key=wandb_key)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mwilber-quito[0m ([33mdeepsat[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [4]:
class WANDBConfig:

    def __init__(self, job_type: str, config_path: str):
        self.job_type = job_type
        self.config_path = config_path

    def __enter__(self):
        config = {}
        with open(self.config_path) as f:
            config = json.load(f)
            if config["accelerate"]:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                config["device"] = device
            else:
                config["device"] = "cpu"

        wandb.init(job_type=self.job_type, config=config)

    def __exit__(self, exc_type, exc_val, exc_tb):
        wandb.finish()

In [5]:
class ClearCache:
    def __enter__(self):
        torch.cuda.empty_cache()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.empty_cache()

In [6]:
def dataloader():
    # Transforms images to a PyTorch Tensor
    tensor_transform = transforms.ToTensor()

    # Download the MNIST Dataset
    dataset = datasets.MNIST(
        root="./data", train=True, download=True, transform=tensor_transform
    )

    generator = torch.Generator().manual_seed(42)
    train_ae_ds, val_ae_ds, train_classifier_ds, val_classifier_ds = (
        torch.utils.data.random_split(
            dataset, [0.6, 0.1, 0.2, 0.1], generator=generator
        )
    )

    # DataLoader is used to load the dataset
    # for training
    train_ae_loader = torch.utils.data.DataLoader(
        dataset=train_ae_ds, batch_size=wandb.config.batch_size, shuffle=True
    )

    val_ae_loader = torch.utils.data.DataLoader(
        dataset=val_ae_ds, batch_size=wandb.config.batch_size, shuffle=True
    )

    train_classifier_loader = torch.utils.data.DataLoader(
        dataset=train_classifier_ds, batch_size=wandb.config.batch_size, shuffle=True
    )

    val_classifier_loader = torch.utils.data.DataLoader(
        dataset=val_classifier_ds, batch_size=wandb.config.batch_size, shuffle=True
    )
    return {
        "ae_train": train_ae_loader,
        "ae_val": val_ae_loader,
        "classifier_train": train_classifier_loader,
        "classifier_val": val_classifier_loader,
    }

In [7]:
class AE(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Building an linear encoder with Linear
        # layer followed by Relu activation function
        # 784 ==> 9
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 10),
        )

        # Building an linear decoder with Linear
        # layer followed by Relu activation function
        # The Sigmoid activation function
        # outputs the value between 0 and 1
        # 9 ==> 784
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(10, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 28 * 28),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [8]:
class SSLClassifier(torch.nn.Module):

    def __init__(self, ae: AE):

        super().__init__()

        self.encoder = ae.encoder
        for params in self.encoder.parameters():
            params.requires_grad = False

        # Defines model for classifying digits
        self.classifier = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(10, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 10),
            torch.nn.Softmax(),
        )

    def forward(self, x):
        x = self.encoder(x)
        logits = self.classifier(x)
        return logits

In [9]:
class SSLClassifier2(torch.nn.Module):

    def __init__(self, ae: AE):

        super().__init__()

        self.encoder = ae.encoder

        # Defines model for classifying digits
        self.classifier = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(10, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 10),
            torch.nn.Softmax(),
        )

    def forward(self, x):
        x = self.encoder(x)
        logits = self.classifier(x)
        return logits

In [10]:
def train_step(model, optimizer, loss_fn, loader):

    model.train()

    batch_loss = 0.0
    batch_acc = 0.0

    for image, label in loader:

        # Reshaping the image to (-1, 784)
        image = image.reshape(-1, 28 * 28)

        # Moving tensors to device
        image = image.to(wandb.config.device)
        label = label.to(wandb.config.device)

        # Output of classifier
        pred = model(image)

        # Calculating the loss function
        loss = loss_fn(pred, label)

        # The gradients are set to zero,
        # the gradient is computed and stored.
        # .step() performs parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss += loss.item() * image.size(0)
        batch_acc += torch.sum(torch.argmax(pred, dim=1) == label).item()

    batch_loss = batch_loss / len(loader.dataset)
    batch_acc = batch_acc / len(loader.dataset)

    return batch_loss, batch_acc


def val_step(model, loss_fn, loader):

    model.eval()

    batch_loss = 0.0
    batch_acc = 0.0

    for image, label in loader:

        # Reshaping the image to (-1, 784)
        image = image.reshape(-1, 28 * 28)

        # Moving tensors to device
        image = image.to(wandb.config.device)
        label = label.to(wandb.config.device)

        # Output of Autoencoder
        pred = model(image)

        # Calculating the loss function
        loss = loss_fn(pred, label)
        batch_loss += loss.item() * image.size(0)

        # Calculating the acc
        batch_acc += torch.sum(torch.argmax(pred, dim=1) == label).item()

    batch_loss = batch_loss / len(loader.dataset)
    batch_acc = batch_acc / len(loader.dataset)

    return batch_loss, batch_acc


def train(model, optimizer, loss_fn, train_loader, val_loader):

    model = model.to(wandb.config.device)

    for epoch in tqdm(range(wandb.config.epochs + 1)):
        train_loss, train_acc = train_step(model, optimizer, loss_fn, train_loader)
        val_loss, val_acc = val_step(model, loss_fn, val_loader)
        wandb.log(
            {
                "train_loss": train_loss,
                "train_acc": train_acc,
                "val_loss": val_loss,
                "val_acc": val_acc,
                "epoch": epoch,
            }
        )

In [11]:
config_path = "config/classifier-ssl.json"
job_type = "classifier-ssl"

with WANDBConfig(job_type, config_path) as cf, ClearCache() as cc:

    # Dataloader
    loaders = dataloader()
    train_loader = loaders["classifier_train"]
    val_loader = loaders["classifier_val"]

    # Loads AutoEncoder weights
    ae = AE()
    ae_model_state = torch.load(wandb.config.trained_autoencoder_path)
    ae.load_state_dict(ae_model_state)

    # Instanciate new digit classifier model
    model = SSLClassifier(ae) if wandb.config.freeze_encoder else SSLClassifier2(ae)

    # Validation using Cross Entropy Loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    # Using an Adam Optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=wandb.config.lr, weight_decay=1e-8
    )
    train(model, optimizer, loss_fn, train_loader, val_loader)

  return self._call_impl(*args, **kwargs)
100%|██████████| 41/41 [06:57<00:00, 10.19s/it]


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
train_acc,▁▃▅██████▃███▅▃██▇█▇▄▆█████▅█▆▇████▇██▇▇
train_loss,█▄▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
val_acc,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,▁▁▄▅▅▄▃▃▃▄▃▂▄▂▃▂▅█▂▃▃▃▁▃▅▄▃▆▄▅▇▄█▃▄▅▅▅▃▇

0,1
epoch,40.0
train_acc,0.11367
train_loss,2.30079
val_acc,0.111
val_loss,2.30134


In [12]:
torch.save(model.state_dict(), "pth/classifier-ssl.pth")