In [3]:
# Show the status of GPU 0 (driver / CUDA version, memory use) before training.
# `!cmd` is IPython shell syntax: the command's output is captured as a list of lines,
# which is then joined back into one string for printing.
gpu_info = !nvidia-smi -i 0
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

from datetime import datetime
from functools import partial
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet
from tqdm import tqdm
import argparse
import json
import math
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F 
from configs import model_config


from main_utils import get_model, get_optimizer
from cifar_dataset import train_dataloader, train_val_dataloader, test_dataloader

Wed May 17 14:20:13 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 517.00       Driver Version: 517.00       CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
| N/A   57C    P8    12W /  N/A |    499MiB /  6144MiB |     15%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
import torch.nn as nn
import torch
import torchvision.utils
import torchvision
import torch.nn.functional as F

import copy

class EMA:
    """Exponential moving average of a value (a scalar or a tensor).

    ``update_average(old, new)`` returns ``alpha * old + (1 - alpha) * new``;
    when there is no previous value yet (``old is None``) it returns ``new``.
    """

    def __init__(self, alpha):
        super().__init__()
        # Weight kept on the previous average; (1 - alpha) goes to the new value.
        self.alpha = alpha

    def update_average(self, old, new):
        if old is not None:
            return self.alpha * old + (1 - self.alpha) * new
        return new

# BYOL model: online network (backbone + projector) + predictor, and an EMA-updated target copy.
class BYOLNetwork(nn.Module):
    """Bootstrap Your Own Latent (BYOL) self-supervised model.

    Components:
      * ``online``    -- ResNet-50 backbone (with an MLP ``fc`` head) followed by a
                         projection MLP; trained by the optimizer.
      * ``predictor`` -- MLP applied on top of the online projections.
      * ``target``    -- frozen deep copy of ``online``; it is changed only by
                         :meth:`update_moving_average` (EMA of the online weights).

    Args:
        in_features: output width of the backbone (input width of the projector).
        hidden_size: width of the backbone ``fc`` output and hidden width of the projector.
        embedding_size: output width of the projector (= output of online/target).
        projection_size: input/output width of the predictor. It must equal
            ``embedding_size`` for the predictor to accept online outputs and for its
            output to be comparable with target outputs.
        projection_hidden_size: hidden width of the predictor.
        batch_norm_mlp: whether the projector MLP uses BatchNorm1d.
    """

    def __init__(self, in_features=512, hidden_size=4096, embedding_size=256, projection_size=256, projection_hidden_size=2048, batch_norm_mlp=True):
        super(BYOLNetwork, self).__init__()
        self.online = self.get_rep_and_proj(in_features, embedding_size, hidden_size, batch_norm_mlp)
        self.predictor = self.get_cnn_block(projection_size, projection_size, projection_hidden_size)
        self.target = self.get_target()
        self.ema = EMA(0.99)

    @torch.no_grad()
    def get_target(self):
        """Return a frozen deep copy of ``self.online`` to serve as the target network."""
        target = copy.deepcopy(self.online)
        # The target is never trained by backprop. Freezing it keeps its parameters out
        # of optimizer param groups built by filtering on requires_grad.
        target.requires_grad_(False)
        return target

    def get_cnn_block(self, dim, embedding_size=256, hidden_size=2048, batch_norm_mlp=True):
        """Build a Linear -> (BatchNorm1d | Identity) -> ReLU -> Linear MLP.

        Despite the name there is no convolution: this MLP is used for both the
        projector and the predictor.

        Args:
            dim: input width.
            embedding_size: output width.
            hidden_size: hidden width.
            batch_norm_mlp: use BatchNorm1d on the hidden layer (default True, which
                matches the previous behaviour where BatchNorm was always applied).
        """
        norm = nn.BatchNorm1d(hidden_size) if batch_norm_mlp else nn.Identity()
        return nn.Sequential(
            nn.Linear(dim, hidden_size),
            norm,
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, embedding_size)
        )

    def get_rep_and_proj(self, in_features, embedding_size, hidden_size, batch_norm_mlp):
        """Build the online network: ResNet-50 backbone followed by the projector MLP.

        Side effect: also stores the backbone as ``self.backbone`` (the same module
        object that sits inside the returned Sequential).
        """
        self.backbone = torchvision.models.resnet50(num_classes=hidden_size)  # fc outputs hidden_size
        # The MLP for g(.) consists of Linear->ReLU->Linear
        self.backbone.fc = nn.Sequential(
            self.backbone.fc,  # Linear(ResNet-50 pooled features -> hidden_size)
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, in_features)
        )
        proj = self.get_cnn_block(in_features, embedding_size, hidden_size=hidden_size, batch_norm_mlp=batch_norm_mlp)
        return nn.Sequential(self.backbone, proj)

    @torch.no_grad()
    def update_moving_average(self):
        """EMA step: target <- ema.update_average(target, online), parameter by parameter.

        Only parameters are averaged; buffers (e.g. BatchNorm running stats) are not.
        """
        for online_param, target_param in zip(self.online.parameters(), self.target.parameters()):
            # In-place copy keeps the target Parameter objects (and their flags) intact.
            target_param.copy_(self.ema.update_average(target_param, online_param))

    def byol_loss(self, x, y):
        """Per-sample BYOL loss ``2 - 2 * cos(x, y)`` along the last dimension."""
        # L2 normalization
        x = F.normalize(x, dim=-1, p=2)
        y = F.normalize(y, dim=-1, p=2)
        loss = 2 - 2 * (x * y).sum(dim=-1)
        return loss

    def forward(self, x1, x2=None, return_embedding=False):
        """Return online embeddings, or the symmetric BYOL loss for two views.

        With only ``x1`` (or ``return_embedding=True``), returns ``online(x1)``.
        Otherwise returns the batch mean of the symmetrized loss: the prediction
        for each view is regressed onto the target projection of the OTHER view.
        """
        if return_embedding or (x2 is None):
            return self.online(x1)

        # online projections: backbone + MLP projection
        x1_1 = self.online(x1)
        x1_2 = self.online(x2)

        # additional online's MLP head called predictor
        x1_1_pred = self.predictor(x1_1)
        x1_2_pred = self.predictor(x1_2)

        # no_grad already prevents gradients from flowing into the target.
        with torch.no_grad():
            # teacher processes the images and makes projections: backbone + MLP
            x2_1 = self.target(x1)
            x2_2 = self.target(x2)

        # Cross-view pairing: view-1 prediction vs view-2 target, and vice versa.
        loss = (self.byol_loss(x1_1_pred, x2_2) + self.byol_loss(x1_2_pred, x2_1)).mean()

        return loss

# The BYOLNetwork class above is not instantiated directly here:
# model = BYOLNetwork().cuda()

# Build the model through the project helper; `ckpt` is later passed to get_optimizer.
# NOTE(review): presumably get_model(name="byol") builds a BYOLNetwork-like model
# (the printed repr below is a BYOLNetwork) — confirm against main_utils.
model, _, ckpt = get_model(name="byol", conf={}, resume=False)

In [5]:
# Display the model architecture (rich repr of the module tree).
model

BYOLNetwork(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
     

In [6]:
# train for one epoch
def train(net, data_loader, train_optimizer, epoch):
    """Run one BYOL training epoch and return the per-sample mean loss.

    Args:
        net: model whose ``forward(im_1, im_2)`` returns the scalar BYOL loss and
            which exposes ``update_moving_average()`` for the target EMA step.
        data_loader: iterable of ``(im_1, im_2)`` pairs of augmented views
            (presumably image tensors of shape (B, C, H, W) — TODO confirm).
        train_optimizer: optimizer to step; its learning rate is reset for this
            epoch via ``adjust_learning_rate``.
        epoch: epoch index used by the LR schedule and the progress bar.

    Returns:
        Sample-weighted average loss over the epoch.
    """
    net.train()
    # Schedule the optimizer that was passed in, not a notebook-global one.
    adjust_learning_rate(train_optimizer, epoch)

    total_loss, total_num = 0.0, 0
    train_bar = tqdm(data_loader)
    for im_1, im_2 in train_bar:
        im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True)

        loss = net(im_1, im_2)

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        # EMA step for the target network, after the online weights changed.
        net.update_moving_average()

        # Use the real batch size: the last batch may be smaller than
        # data_loader.batch_size, which can also be None with a custom batch_sampler.
        batch_size = im_1.size(0)
        total_num += batch_size
        total_loss += loss.item() * batch_size
        train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(epoch, model_config["EPOCHS"], train_optimizer.param_groups[0]['lr'], total_loss / total_num))

    return total_loss / total_num

# lr scheduler for training
def adjust_learning_rate(optimizer, epoch, schedule=None, base_lr=None, epochs=None):
    """Decay the learning rate based on schedule.

    Sets ``param_group['lr']`` for every param group of ``optimizer``.

    * Default (``schedule is None``): cosine schedule,
      ``lr = base_lr * 0.5 * (1 + cos(pi * epoch / epochs))``.
    * ``schedule`` given (iterable of milestone epochs): step schedule, where lr is
      multiplied by 0.1 for every milestone with ``epoch >= milestone``.

    Args:
        optimizer: torch optimizer whose ``param_groups`` are updated in place.
        epoch: current epoch index.
        schedule: optional milestones enabling the step schedule.
        base_lr: base learning rate; defaults to ``model_config["LEARNING_RATE"]``.
        epochs: total epochs for the cosine schedule; defaults to ``model_config["EPOCHS"]``.
    """
    lr = model_config["LEARNING_RATE"] if base_lr is None else base_lr
    if schedule is None:  # cosine lr schedule
        total_epochs = model_config["EPOCHS"] if epochs is None else epochs
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / total_epochs))
    else:  # stepwise lr schedule
        for milestone in schedule:
            lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [7]:
# test using a knn monitor
def test(net, memory_data_loader, test_data_loader, epoch, knn_k=200, knn_t=0.1):
    """Evaluate ``net`` with a weighted k-NN monitor and return top-1 accuracy (%).

    Features of ``memory_data_loader`` form an L2-normalized memory bank. Each test
    sample is then classified by ``knn_predict`` over that bank.

    Args:
        net: feature extractor; ``net(x)`` returns a 2-D feature batch.
        memory_data_loader: yields ``(data, target)``; its dataset must expose ``.classes``.
        test_data_loader: yields ``(data, target)`` evaluation batches.
        epoch: epoch index shown in the progress bar.
        knn_k: number of neighbours to vote (default 200).
        knn_t: softmax-style temperature for the neighbour weights (default 0.1).
    """
    net.eval()
    classes = len(memory_data_loader.dataset.classes)
    total_top1, total_num = 0.0, 0
    feature_bank, label_bank = [], []
    with torch.no_grad():
        # Generate the feature bank. Labels are collected from the same batches so they
        # stay aligned with the features even if the memory loader shuffles or drops samples.
        for data, target in tqdm(memory_data_loader, desc='Feature extracting'):
            feature = net(data.cuda(non_blocking=True))
            feature_bank.append(F.normalize(feature, dim=1))
            label_bank.append(target)
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.cat(label_bank, dim=0).to(feature_bank.device)
        # loop test data to predict the label by weighted knn search
        test_bar = tqdm(test_data_loader)
        for data, target in test_bar:
            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            feature = F.normalize(net(data), dim=1)

            pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t)

            total_num += data.size(0)
            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
            test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(epoch, model_config["EPOCHS"], total_top1 / total_num * 100))

    return total_top1 / total_num * 100

# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """Weighted k-NN class ranking over a memory bank.

    Args:
        feature: [B, D] query features (expected to be L2-normalized, so that the dot
            product with the bank is cosine similarity).
        feature_bank: [D, N] bank features, transposed, normalized like ``feature``.
        feature_labels: [N] integer class labels of the bank entries.
        classes: number of classes C.
        knn_k: neighbours that vote. It is clamped to N so that a bank smaller than
            ``knn_k`` does not make ``topk`` raise.
        knn_t: temperature; each neighbour votes with weight ``exp(sim / knn_t)``.

    Returns:
        [B, C] class indices sorted by descending weighted vote; column 0 is the
        top-1 prediction.
    """
    batch_size = feature.size(0)
    knn_k = min(knn_k, feature_bank.size(1))
    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(batch_size, -1), dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # one-hot votes ---> [B*K, C]
    one_hot_label = torch.zeros(batch_size * knn_k, classes, device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(batch_size, -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

    return pred_scores.argsort(dim=-1, descending=True)

In [8]:
def get_params_groups(model):
    """Split trainable parameters into weight-decayed and non-decayed groups.

    Biases (names ending in ".bias") and any 1-D parameter (e.g. normalization
    scales/shifts) go to a group with ``weight_decay`` 0; all other trainable
    parameters keep the optimizer's default weight decay. Frozen parameters
    (``requires_grad=False``) are left out entirely.
    """
    decayed, undecayed = [], []
    for param_name, tensor in model.named_parameters():
        if tensor.requires_grad:
            skip_decay = param_name.endswith(".bias") or tensor.ndim == 1
            (undecayed if skip_decay else decayed).append(tensor)
    return [{'params': decayed}, {'params': undecayed, 'weight_decay': 0.}]


In [9]:
# Move the model to the GPU and freeze the target network BEFORE building the optimizer.
# PyTorch recommends moving a model to the GPU before constructing its optimizer.
# get_params_groups skips parameters with requires_grad=False, so freezing first keeps
# the EMA-only target weights out of the optimizer's param groups.
model.cuda()
model.target.requires_grad_(False)

# define optimizer
# NOTE(review): lr0 is overwritten at the start of every epoch by adjust_learning_rate,
# which reads model_config["LEARNING_RATE"].
optimizer, _ = get_optimizer(get_params_groups(model), conf={}, resume=False, ckpt=ckpt, optimizer="SGD", lr0=0.06, weight_decay=5e-4, momentum=0.9)

# Resuming is not wired up: training always starts at epoch 1. A resume would need to
# restore 'state_dict', 'optimizer' and 'epoch' from the checkpoint saved below.
epoch_start = 1

# logging
save_dir = model_config["SAVE_DIR"]
# makedirs(exist_ok=True) also creates missing parent directories and is safe to re-run.
os.makedirs(save_dir, exist_ok=True)
results = {'train_loss': [], 'test_acc@1': []}

# training loop
for epoch in range(epoch_start, model_config["EPOCHS"] + 1):
    train_loss = train(model, train_dataloader, optimizer, epoch)
    results['train_loss'].append(train_loss)
    # k-NN monitor on the online network only.
    test_acc_1 = test(model.online, train_val_dataloader, test_dataloader, epoch)
    results['test_acc@1'].append(test_acc_1)
    # save statistics
    data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1))
    data_frame.to_csv(os.path.join(save_dir, 'log.csv'), index_label='epoch')
    # save model
    torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, os.path.join(save_dir, 'model_last.pth'))

Train Epoch: [1/800], lr: 0.060000, Loss: 2.6510:   3%|▎         | 3/97 [00:06<03:35,  2.29s/it]


KeyboardInterrupt: 