<a href="https://colab.research.google.com/github/harrywinks/leaf_classification/blob/master/BYOL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BYOL
PyTorch implementation of "Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning" by J.B. Grill et al. This notebook adds support for PyTorch <= 1.5.0 and a practical dataset example (CIFAR-10).



### Clone code

In [None]:
# Fetch the BYOL reference code together with its git submodules
# (--recurse-submodules); -j8 lets git fetch up to 8 submodules in parallel.
!git clone https://github.com/spijkervet/byol --recurse-submodules -j8
# Make the clone the working directory; the later `from modules import ...`
# cell presumably relies on this to resolve the `modules` package.
%cd byol

Cloning into 'byol'...
remote: Enumerating objects: 65, done.[K
remote: Counting objects: 100% (65/65), done.[K
remote: Compressing objects: 100% (43/43), done.[K
remote: Total 65 (delta 25), reused 46 (delta 13), pack-reused 0[K
Unpacking objects: 100% (65/65), done.
Submodule 'modules/byol' (https://github.com/spijkervet/byol-pytorch) registered for path 'modules/byol'
Cloning into '/content/byol/modules/byol'...
remote: Enumerating objects: 125, done.        
remote: Counting objects: 100% (125/125), done.        
remote: Compressing objects: 100% (82/82), done.        
remote: Total 125 (delta 54), reused 93 (delta 33), pack-reused 0        
Receiving objects: 100% (125/125), 48.47 KiB | 291.00 KiB/s, done.
Resolving deltas: 100% (54/54), done.
Submodule path 'modules/byol': checked out '8cc35a298df67e6f6535bef9ab94992a186f4aca'
/content/byol


### Import modules

In [None]:
import argparse
import os
import pickle
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, datasets

# distributed training
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

In [None]:
from modules import BYOL
from modules.transformations import TransformsSimCLR

## Pre-training

### Define hyperparameters

In [None]:
# Hyperparameters for BYOL pre-training, declared once as
# (flag, default, type, help) rows and registered in a single loop.
_pretrain_options = [
    ("--image_size", 224, int, "Image size"),
    ("--learning_rate", 3e-4, float, "Initial learning rate."),
    ("--batch_size", 192, int, "Batch size for training."),
    ("--num_epochs", 100, int, "Number of epochs to train for."),
    ("--resnet_version", "resnet18", str, "ResNet version."),
    ("--checkpoint_epochs", 10, int, "Number of epochs between checkpoints/summaries."),
    ("--dataset_dir", "./datasets", str, "Directory where dataset is stored."),
    ("--num_workers", 8, int, "Number of data loading workers (caution with nodes!)"),
    ("--nodes", 1, int, "Number of nodes"),
    ("--gpus", 1, int, "number of gpus per node"),
    ("--nr", 0, int, "ranking within the nodes"),
]

parser = argparse.ArgumentParser()
for _flag, _default, _type, _help in _pretrain_options:
    parser.add_argument(_flag, default=_default, type=_type, help=_help)

# Colab work-around: parse an explicit empty argument list instead of calling
# parser.parse_args() with no arguments (which would read sys.argv), so every
# option takes its default.
args = parser.parse_args(args=[])

### Pre-train ResNet

In [None]:
# Pre-train a ResNet backbone with BYOL on the CIFAR-10 training split.
# Checkpoints store `resnet.state_dict()` (the backbone passed into BYOL),
# written to ./model-<epoch>.pt during training and ./model-final.pt at the end.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
torch.manual_seed(0)

# dataset: each sample is consumed below as a pair of views plus a label
train_dataset = datasets.CIFAR10(
    args.dataset_dir,
    download=True,
    transform=TransformsSimCLR(size=args.image_size), # paper 224
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,  # reshuffle every epoch; the default (False) replays identical batches
    drop_last=True,
    num_workers=args.num_workers,
)

# model
if args.resnet_version == "resnet18":
    resnet = models.resnet18(pretrained=False)
elif args.resnet_version == "resnet50":
    resnet = models.resnet50(pretrained=False)
else:
    raise NotImplementedError("ResNet not implemented")

model = BYOL(resnet, image_size=args.image_size, hidden_layer="avgpool")
model = model.to(device)

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

# solver
log_every_n_steps = 1  # raise this to reduce the amount of per-step output
global_step = 0
for epoch in range(args.num_epochs):
    metrics = defaultdict(list)
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        x_i = x_i.to(device)
        x_j = x_j.to(device)

        # the BYOL wrapper returns the training loss for the two views
        loss = model(x_i, x_j)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model.update_moving_average()  # update moving average of target encoder

        if step % log_every_n_steps == 0:
            print(f"Step [{step}/{len(train_loader)}]:\tLoss: {loss.item()}")

        metrics["Loss/train"].append(loss.item())
        global_step += 1

    # report epoch-level averages (printed only; no TensorBoard writer is used here)
    print(f"Epoch [{epoch}/{args.num_epochs}]: " + "\t".join([f"{k}: {np.array(v).mean()}" for k, v in metrics.items()]))

    # periodic checkpoint of the backbone weights (also fires at epoch 0)
    if epoch % args.checkpoint_epochs == 0:
        print(f"Saving model at epoch {epoch}")
        torch.save(resnet.state_dict(), f"./model-{epoch}.pt")


# save your improved network; the linear-evaluation stage loads this file,
# so it only exists if the loop above ran to completion
torch.save(resnet.state_dict(), "./model-final.pt")

Files already downloaded and verified
Step [0/260]:	Loss: 3.972365379333496
Step [1/260]:	Loss: 2.211014747619629
Step [2/260]:	Loss: 1.7506556510925293
Step [3/260]:	Loss: 1.6225242614746094
Step [4/260]:	Loss: 1.5917638540267944


KeyboardInterrupt: ignored

## Linear evaluation

In [None]:
from process_features import get_features, create_data_loaders_from_arrays

# Hyperparameters for the linear-evaluation stage.
# NOTE: this rebinds `parser` and `args`, replacing the pre-training settings.
parser = argparse.ArgumentParser()
parser.add_argument("--image_size", default=224, type=int, help="Image size")
parser.add_argument(
    "--learning_rate", default=3e-3, type=float, help="Initial learning rate."
)
parser.add_argument(
    "--batch_size", default=768, type=int, help="Batch size for training."
)
parser.add_argument(
    "--num_epochs", default=300, type=int, help="Number of epochs to train for."
)
# Kept in sync with the pre-training default so the evaluation stage can
# rebuild the same backbone architecture that produced the checkpoint.
parser.add_argument(
    "--resnet_version", default="resnet18", type=str, help="ResNet version."
)
parser.add_argument(
    "--checkpoint_epochs",
    default=10,
    type=int,
    help="Number of epochs between checkpoints/summaries.",
)
parser.add_argument(
    "--dataset_dir",
    default="./datasets",
    type=str,
    help="Directory where dataset is stored.",
)
parser.add_argument(
    "--num_workers",
    default=8,
    type=int,
    help="Number of data loading workers (caution with nodes!)",
)
# Declared as a real option (same default as before) so it shows up in the
# parser's help and can be overridden like every other setting.
parser.add_argument(
    "--model_path",
    default="model-final.pt",
    type=str,
    help="Path to the pre-trained ResNet state dict.",
)

# Colab work-around: parse an explicit empty list rather than sys.argv.
args = parser.parse_args(args=[])

### Data loaders

In [None]:
# Build CIFAR-10 train/test loaders for the linear-evaluation stage.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Both splits use the `test_transform` attribute of TransformsSimCLR rather
# than the two-view training transform used during pre-training.
eval_transform = TransformsSimCLR(size=args.image_size).test_transform

# data loaders
train_dataset = datasets.CIFAR10(
    args.dataset_dir,
    download=True,
    transform=eval_transform,
)

test_dataset = datasets.CIFAR10(
    args.dataset_dir,
    train=False,
    download=True,
    transform=eval_transform,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    drop_last=True,
    num_workers=args.num_workers,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    # keep the final partial batch: dropping it would silently exclude test
    # images from the reported accuracy
    drop_last=False,
    num_workers=args.num_workers,
)

Files already downloaded and verified
Files already downloaded and verified


### Load pre-trained model, loss and optimizer

In [None]:
# pre-trained model: rebuild the architecture that was pre-trained so the saved
# state dict matches. Falls back to "resnet18" (the pre-training default) when
# `args` carries no `resnet_version`.
resnet_version = getattr(args, "resnet_version", "resnet18")
if resnet_version == "resnet18":
    resnet = models.resnet18()
elif resnet_version == "resnet50":
    resnet = models.resnet50()
else:
    raise NotImplementedError("ResNet not implemented")

resnet.load_state_dict(torch.load(args.model_path, map_location=device))
resnet = resnet.to(device)
# The encoder is only used as a fixed feature extractor from here on, so put
# its BatchNorm layers in inference mode (use running statistics, no updates).
resnet.eval()

# input width of the final child module (the fc layer removed below)
num_features = list(resnet.children())[-1].in_features

# throw away fc layer
resnet = nn.Sequential(*list(resnet.children())[:-1])
n_classes = 10 # CIFAR-10 has 10 classes

# fine-tune model: a single linear layer on top of the encoder features
logreg = nn.Sequential(nn.Linear(num_features, n_classes))
logreg = logreg.to(device)

# loss / optimizer (only the linear layer's parameters are optimised)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=logreg.parameters(), lr=args.learning_rate)

FileNotFoundError: ignored

### Compute features
The data only needs to pass through the pre-trained model once, because no gradients are backpropagated into it. Caching the resulting features saves computation time when training the fine-tuned model (a logistic regression model).

In [None]:
# compute features (only needs to be done once, since it does not backprop during fine-tuning)
# The result is cached in ./features.p. The cache is keyed only by file name:
# delete it after changing the checkpoint or the data, or stale features are reused.
if not os.path.exists("features.p"):
    print("### Creating features from pre-trained model ###")
    (train_X, train_y, test_X, test_y) = get_features(
        resnet, train_loader, test_loader, device
    )
    # context manager guarantees the cache file is flushed and closed
    with open("features.p", "wb") as cache_file:
        pickle.dump((train_X, train_y, test_X, test_y), cache_file, protocol=4)
else:
    print("### Loading features ###")
    # pickle.load executes arbitrary code from the file; only load a cache
    # that this notebook wrote itself.
    with open("features.p", "rb") as cache_file:
        (train_X, train_y, test_X, test_y) = pickle.load(cache_file)


# Replace the image loaders with loaders over the cached feature arrays.
# NOTE(review): the trailing 2048 is presumably the batch size - confirm
# against create_data_loaders_from_arrays.
train_loader, test_loader = create_data_loaders_from_arrays(
    train_X, train_y, test_X, test_y, 2048
)

### Train fine-tuned model

In [None]:
# Fit the linear classifier (`logreg`) on the batches served by train_loader,
# printing the mean loss and accuracy of each epoch.
for epoch in range(args.num_epochs):
    metrics = defaultdict(list)
    for features, labels in train_loader:
        features, labels = features.to(device), labels.to(device)

        logits = logreg(features)
        loss = criterion(logits, labels)

        # standard optimisation step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # per-batch bookkeeping: loss value and fraction of correct predictions
        n_correct = (logits.argmax(1) == labels).sum().item()
        metrics["Loss/train"].append(loss.item())
        metrics["Accuracy/train"].append(n_correct / labels.size(0))

    epoch_summary = "\t".join(
        f"{name}: {np.array(values).mean()}" for name, values in metrics.items()
    )
    print(f"Epoch [{epoch}/{args.num_epochs}]: " + epoch_summary)

### Test fine-tuned model

In [None]:
# Test fine-tuned model: accuracy of `logreg` over every sample served by test_loader.
print("### Calculating final testing performance ###")
n_correct = 0
n_seen = 0
# Inference only: disable autograd so no graph is built for the forward passes.
with torch.no_grad():
    for h, y in test_loader:
        h = h.to(device)
        y = y.to(device)

        outputs = logreg(h)

        # Accumulate raw counts instead of averaging per-batch accuracies, so a
        # smaller final batch is weighted by its actual size.
        n_correct += (outputs.argmax(1) == y).sum().item()
        n_seen += y.size(0)

# An empty loader yields no metric instead of a division by zero.
metrics = {"Accuracy/test": n_correct / n_seen} if n_seen else {}

print("Final test performance: " + "\t".join([f"{k}: {v}" for k, v in metrics.items()]))