In [1]:
gpu_info = !nvidia-smi -i 0
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

from datetime import datetime
from functools import partial
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet
from tqdm import tqdm
import argparse
import json
import math
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from configs import model_config

Sat May 20 13:51:25 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 517.00       Driver Version: 517.00       CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
| N/A   58C    P8    16W /  N/A |   2222MiB /  6144MiB |     41%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
root_path = "D:\Ai\Projects\self-supervised-learning\data"

class CIFAR10Pair(CIFAR10):
    """CIFAR10 Dataset.
    """
    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            im_1 = self.transform(img)
            im_2 = self.transform(img)

        return im_1, im_2

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

# data prepare
train_data = CIFAR10Pair(root=root_path, train=True, transform=train_transform, download=True)
train_dataloader = DataLoader(train_data, batch_size=model_config["batch_size"], shuffle=True, num_workers=0, pin_memory=True, drop_last=True)

memory_data = CIFAR10(root=root_path, train=True, transform=test_transform, download=True)
train_val_dataloader = DataLoader(memory_data, batch_size=model_config["batch_size"], shuffle=False, num_workers=0, pin_memory=True)

test_data = CIFAR10(root=root_path, train=False, transform=test_transform, download=True)
test_dataloader = DataLoader(test_data, batch_size=model_config["batch_size"], shuffle=False, num_workers=0, pin_memory=True)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:

import torch.nn as nn
import torch
import torchvision.utils
import torchvision
import torch.nn.functional as F

import copy
from utils import EMA
from configs import model_config

#create the Siamese Neural Network
class MOCO(nn.Module):

    def __init__(self, in_features=512, hidden_size=4096, embedding_size=256, projection_size=256, projection_hidden_size=2048, batch_norm_mlp=True):
        super(MOCO, self).__init__()
        self.online = self.get_representation()
        self.online.mean = nn.Linear(model_config["EMBEDDING_SIZE"], model_config["EMBEDDING_SIZE"])
        self.online.var = nn.Linear(model_config["EMBEDDING_SIZE"], model_config["EMBEDDING_SIZE"])
        self.predictor = self.get_linear_block()

        self.target = self.get_target()
        self.ema = EMA(0.999)

        self.LeakyReLU = nn.LeakyReLU(0.2)
    
    @torch.no_grad()
    def get_target(self):
        return copy.deepcopy(self.online)

    def get_linear_block(self):
        return nn.Sequential(
            nn.Linear(model_config["EMBEDDING_SIZE"], model_config["HIDDEN_SIZE"]),
            nn.BatchNorm1d(model_config["HIDDEN_SIZE"]),
            nn.ReLU(inplace=True),
            nn.Linear(model_config["HIDDEN_SIZE"], model_config["EMBEDDING_SIZE"])
        )

    def get_representation(self):
        return torchvision.models.resnet50(num_classes=model_config["EMBEDDING_SIZE"])

    @torch.no_grad()
    def update_moving_average(self):
        for online_params, target_params in zip(self.online.parameters(), self.target.parameters()):
            old_weight, up_weight = target_params.data, online_params.data
            target_params.data = self.ema.update_average(old_weight, up_weight)
            
    def reparameterization(self, mean, logvar):
        var = torch.exp(0.5*logvar)
        epsilon = torch.randn_like(var)      # sampling epsilon        
        z = mean + var * epsilon                          # reparameterization trick
        return z

    def byol_loss(self, x, y):
        # L2 normalization
        x = F.normalize(x, dim=-1, p=2)
        y = F.normalize(y, dim=-1, p=2)
        loss = 2 - 2 * (x * y).sum(dim=-1)
        return loss

    def iso_kl(self, mean, log_var):
        # indices = find_inf_nan_indices(kl)
        # print(indices)
        # print(log_var)
        # if torch.isnan(kl) or torch.isinf(kl):
        # print("log_var: ", log_var)
        # print("log_var.exp: ", log_var.exp())
        # print("mean.pow: ", mean.pow(2))
        return - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())

    def kl_divergence(self, mu1, log_var1, mu2, log_var2):
        var1 = torch.exp(log_var1)
        var2 = torch.exp(log_var2)

        term1 = (var1 / var2 - 1).sum(dim=1)
        term2 = ((mu2 - mu1).pow(2) / var2).sum(dim=1)
        term3 = (log_var2 - log_var1).sum(dim=1)
        kl_div = 0.5 * (term1 + term2 + term3)
        
        return kl_div.sum()

    def forward_once(self, x):
        embedding_o = self.online(x)
        mean_o = self.online.mean(self.LeakyReLU(embedding_o))
        logvar_o = self.online.var(self.LeakyReLU(embedding_o)),
        z_o = self.reparameterization(mean_o, logvar_o)
        z_o_p = self.predictor(z_o)

        with torch.no_grad():
            embedding_tar = self.target(x).detach()
            mean_tar = self.target.mean(self.LeakyReLU(embedding_tar)).detach()
            logvar_tar = self.target.var(self.LeakyReLU(embedding_tar)).detach()
            z_tar = self.reparameterization(mean_tar, logvar_tar).detach()

        distance_loss = self.byol_loss(z_o_p, z_tar)

        kl_loss = self.kl_divergence(mean_o, logvar_o, mean_tar, logvar_tar)

        iso_kl_loss = self.iso_kl(mean_o, logvar_o)
        iso_kl_loss += self.iso_kl(mean_tar, logvar_tar)

        # if torch.isnan(kl_loss) or torch.isinf(kl_loss):
        #     print("------------------------")
        #     print("kl_total: ", kl_loss)
            # print("logvar_o: ", logvar_o)
            # print("logvar_tar: ", logvar_tar)


        return kl_loss, distance_loss, iso_kl_loss, embedding_o
        

    def forward(self, x1, x2=None):
        if x2 is None:
            return self.online(x1)

        kl_loss1, distance_loss1, iso_kl_loss1, embedding_o1 = self.forward_once(x1)
        kl_loss2, distance_loss2, iso_kl_loss2, embedding_o2 = self.forward_once(x2)

        

        kl_total = kl_loss1 + kl_loss2
        iso_kl_total = iso_kl_loss1 + iso_kl_loss2
        distance_total = (distance_loss1 + distance_loss2).mean()

        total_loss =  distance_total  + iso_kl_total

        print("kl_loss: ", kl_total)
        print("distance_total: ", distance_total)
        print("iso_kl_total: ", iso_kl_total)

        return total_loss

def find_inf_nan_indices(tensor):
    # Check for inf and nan values
    is_inf = torch.isinf(tensor)
    is_nan = torch.isnan(tensor)

    # Combine the masks to find the indices where either inf or nan is present
    inf_nan_mask = torch.logical_or(is_inf, is_nan)

    # Get the indices where inf or nan is present
    indices = torch.nonzero(inf_nan_mask)

    return indices
# temp1 = torch.rand((6, 3, 32, 32))
# temp2 = torch.rand((6, 3, 32, 32))
# temp_model = MOCO()
# ress = temp_model(temp1, temp2)[0]
# for res in ress:
#     print(res[0].size(), res[1].size(), res[2].size(), res[3].size())
model = MOCO().cuda()

In [4]:
# train for one epoch
def train(net, data_loader, train_optimizer, epoch):
    net.train()
    adjust_learning_rate(optimizer, epoch)

    total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
    for im_1, im_2 in train_bar:
        im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True)

        loss = net(im_1, im_2)
        
        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        net.update_moving_average()

        total_num += data_loader.batch_size
        total_loss += loss.item() * data_loader.batch_size
        train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(epoch, model_config["EPOCHS"], optimizer.param_groups[0]['lr'], total_loss / total_num))

    return total_loss / total_num

# lr scheduler for training
def adjust_learning_rate(optimizer, epoch):
    """Decay the learning rate based on schedule"""
    lr = model_config["LEARNING_RATE"]
    if True:  # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / model_config["EPOCHS"]))
    else:  # stepwise lr schedule
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [5]:
# test using a knn monitor
def test(net, memory_data_loader, test_data_loader, epoch):
    net.eval()
    classes = len(memory_data_loader.dataset.classes)
    total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
    with torch.no_grad():
        # generate feature bank
        for data, target in tqdm(memory_data_loader, desc='Feature extracting'):
            feature = net(data.cuda(non_blocking=True))
            feature = F.normalize(feature, dim=1)
            feature_bank.append(feature)
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
        # loop test data to predict the label by weighted knn search
        test_bar = tqdm(test_data_loader)
        for data, target in test_bar:
            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            feature = net(data)
            feature = F.normalize(feature, dim=1)
            
            pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, 200, 0.1)

            total_num += data.size(0)
            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
            test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(epoch, model_config["EPOCHS"], total_top1 / total_num * 100))

    return total_top1 / total_num * 100

# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # counts for each class
    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
    # [B*K, C]
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels

In [6]:
def get_params_groups(model):
    regularized = []
    not_regularized = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            # print(name)
            continue
        # we do not regularize biases nor Norm parameters
        if name.endswith(".bias") or len(param.shape) == 1:
            not_regularized.append(param)
        else:
            regularized.append(param)
    return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]


In [7]:
model.target.requires_grad_(False)

# define optimizer
optimizer = torch.optim.SGD(get_params_groups(model), lr=0.06, weight_decay=5e-4, momentum=0.9)

# load model if resume
epoch_start = 1

# if resume is not '':
#     checkpoint = torch.load(resume)
#     model.load_state_dict(checkpoint['state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer'])
#     epoch_start = checkpoint['epoch'] + 1
#     print('Loaded from: {}'.format(model_config["RESUME"]))

# logging
results = {'train_loss': [], 'test_acc@1': []}
# if not os.path.exists(model_config["SAVE_DIR"]):
#     os.mkdir(model_config["SAVE_DIR"])
# dump args
# with open(results_dir + '/args.json', 'w') as fid:
#     json.dump(__dict__, fid, indent=2)


# training loop
for epoch in range(epoch_start, model_config["EPOCHS"] + 1):
    train_loss = train(model, train_dataloader, optimizer, epoch)
    results['train_loss'].append(train_loss)
    test_acc_1 = test(model.online, train_val_dataloader, test_dataloader, epoch)
    results['test_acc@1'].append(test_acc_1)
    # save statistics
    data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1))
    # data_frame.to_csv(model_config["SAVE_DIR"] + '/log.csv', index_label='epoch')
    # save model
    # torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict(),}, model_config["SAVE_DIR"] + '/model_last.pth')

  0%|          | 0/97 [00:00<?, ?it/s]

kl_loss:  tensor(0., device='cuda:0', grad_fn=<AddBackward0>)
distance_total:  tensor(4.0214, device='cuda:0', grad_fn=<MeanBackward0>)
iso_kl_total:  tensor(66691.3438, device='cuda:0', grad_fn=<AddBackward0>)


Train Epoch: [1/800], lr: 0.050000, Loss: 66695.3672:   1%|          | 1/97 [00:03<06:17,  3.94s/it]

kl_loss:  tensor(inf, device='cuda:0', grad_fn=<AddBackward0>)
distance_total:  tensor(2.4646, device='cuda:0', grad_fn=<MeanBackward0>)
iso_kl_total:  tensor(1.8806e+20, device='cuda:0', grad_fn=<AddBackward0>)


Train Epoch: [1/800], lr: 0.050000, Loss: 94027938627124756480.0000:   2%|▏         | 2/97 [00:05<03:55,  2.48s/it]

kl_loss:  tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)
distance_total:  tensor(nan, device='cuda:0', grad_fn=<MeanBackward0>)
iso_kl_total:  tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)


Train Epoch: [1/800], lr: 0.050000, Loss: nan:   3%|▎         | 3/97 [00:06<03:06,  1.99s/it]                      

kl_loss:  tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)
distance_total:  tensor(nan, device='cuda:0', grad_fn=<MeanBackward0>)
iso_kl_total:  tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)


Train Epoch: [1/800], lr: 0.050000, Loss: nan:   4%|▍         | 4/97 [00:08<02:42,  1.74s/it]

kl_loss:  tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)
distance_total:  tensor(nan, device='cuda:0', grad_fn=<MeanBackward0>)
iso_kl_total:  tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)


Train Epoch: [1/800], lr: 0.050000, Loss: nan:   5%|▌         | 5/97 [00:09<02:59,  1.96s/it]


KeyboardInterrupt: 