In [1]:
%load_ext autoreload

In [2]:
%autoreload

import os
import numpy as np
import glob
import cv2
import sys
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision
import torchvision.transforms as transforms
from torchvision.io import read_image
from torchvision.utils import save_image
from torchvision import transforms
import torchvision.models as models
from torchvision import transforms

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

from pathlib import Path
import argparse
import json
import math
import random
import signal
import subprocess
import time

from PIL import Image, ImageOps, ImageFilter

import matplotlib.pyplot as plt
%matplotlib inline


print('torch version:', torch.__version__)
print('torchvision version:', torchvision.__version__)
print('pytorch lightning version:', pl.__version__)

torch version: 1.10.2+cu102
torchvision version: 0.11.3+cu102
pytorch lightning version: 1.5.1


In [3]:
    parser = argparse.ArgumentParser(description='Barlow Twins Training')
    parser.add_argument('data', type=Path, metavar='DIR',
                        help='path to dataset')
    parser.add_argument('--workers', default=8, type=int, metavar='N',
                        help='number of data loader workers')
    parser.add_argument('--epochs', default=1000, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--batch-size', default=4096, type=int, metavar='N',
                        help='mini-batch size')
    parser.add_argument('--learning-rate-weights', default=0.2, type=float, metavar='LR',
                        help='base learning rate for weights')
    parser.add_argument('--learning-rate-biases', default=0.0048, type=float, metavar='LR',
                        help='base learning rate for biases and batch norm parameters')
    parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
                        help='weight decay')
    parser.add_argument('--lambd', default=0.0051, type=float, metavar='L',
                        help='weight on off-diagonal terms')
    # parser.add_argument('--projector', default='8192-8192-8192', type=str,
    #                     metavar='MLP', help='projector MLP')
    parser.add_argument('--projector', default='1024-1024-1024', type=str,
                        metavar='MLP', help='projector MLP')
    parser.add_argument('--print-freq', default=100, type=int, metavar='N',
                        help='print frequency')
    parser.add_argument('--checkpoint-dir', default='./checkpoint/', type=Path,
                        metavar='DIR', help='path to checkpoint directory')

    args = parser.parse_args(" ") 
    print(args)



Namespace(batch_size=4096, checkpoint_dir=PosixPath('checkpoint'), data=PosixPath(' '), epochs=1000, lambd=0.0051, learning_rate_biases=0.0048, learning_rate_weights=0.2, print_freq=100, projector='1024-1024-1024', weight_decay=1e-06, workers=8)


In [5]:
args.workers

8

In [17]:
# transform = transforms.Compose(
#     [transforms.ToTensor(),
#      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

from BarlowTwins_CIFAR10_classifier_dataset import *

geo_transforms = A.Compose(
    [
#      A.HorizontalFlip(p=0.5),
#      A.VerticalFlip(p=0.5),
     A.Normalize(),
     ToTensorV2(),
    ]
)

pixel_transforms = A.Compose(
    [
     ToTensorV2(),
    ]
)

batch_size = 200
ds_train = BT_CIFAR10_Classify_Dataset(train=True, geo_transform=geo_transforms, pixel_transform=pixel_transforms)
trainloader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=2)
print('ds_train length:', ds_train.__len__())

ds_val = BT_CIFAR10_Classify_Dataset(train=False, geo_transform=geo_transforms, pixel_transform=pixel_transforms)
valloader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, shuffle=False, num_workers=2) 
print('ds_val length:', ds_val.__len__())


Files already downloaded and verified
ds_train length: 50000
Files already downloaded and verified
ds_val length: 10000


In [10]:
# batch_size = 200

# ds_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=2)

# ds_val = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# valloader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, shuffle=False, num_workers=2)    


Files already downloaded and verified
Files already downloaded and verified


In [18]:
from resnet50_cifar10_classifier import *
#model = resnet50_cifar10_classifier(ds_train, ds_val, args, base_barlow_model_encoder=None)
model = resnet50_cifar10_classifier.load_from_checkpoint(checkpoint_path='./lightning_logs/barlow_classifier/version_1/checkpoints/epoch=50-step=662.ckpt', 
                                   ds_train=ds_train, ds_val=ds_val, args=args, base_barlow_model_encoder=None)
# model = torch.load('./lightning_logs/barlow_classifier/version_1/checkpoints/epoch=66-step=870.ckpt')





batch size: 4096
Creating new base encoder


In [20]:
import torch
import torchmetrics
from torchmetrics import Accuracy

# initialize metric
accuracy = Accuracy(num_classes=10, threshold=0.5, top_k=1)

for (i, batch) in enumerate(valloader):
    x, target = batch
    #target = (torch.nn.functional.one_hot(target, num_classes=10))

    preds = model.forward(x)
    #preds = (torch.argmax(preds, dim=-1)).type(torch.int64) #.unsqueeze(-1)).type(torch.int64)    
    
    # metric on current batch
    accuracy.update(preds, target)
    
    if i % 10 == 0:
        print('batch:', i)
    

# metric on all batches using custom accumulation
print('accuracy:', accuracy.compute())


batch: 0
batch: 10
batch: 20
batch: 30
batch: 40
accuracy: tensor(0.5442)


In [8]:
parser = argparse.ArgumentParser(description='Barlow Twins Training')
parser.add_argument('data', type=Path, metavar='DIR',
                    help='path to dataset')
parser.add_argument('--workers', default=8, type=int, metavar='N',
                    help='number of data loader workers')
parser.add_argument('--epochs', default=1000, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch-size', default=1024, type=int, metavar='N',
                    help='mini-batch size')
parser.add_argument('--learning-rate-weights', default=0.2, type=float, metavar='LR',
                    help='base learning rate for weights')
parser.add_argument('--learning-rate-biases', default=0.0048, type=float, metavar='LR',
                    help='base learning rate for biases and batch norm parameters')
parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
                    help='weight decay')
parser.add_argument('--lambd', default=0.0051, type=float, metavar='L',
                    help='weight on off-diagonal terms')
parser.add_argument('--projector', default='8192-8192-8192', type=str,
                    metavar='MLP', help='projector MLP')
parser.add_argument('--print-freq', default=100, type=int, metavar='N',
                    help='print frequency')
parser.add_argument('--checkpoint-dir', default='./checkpoint/', type=Path,
                    metavar='DIR', help='path to checkpoint directory')

args = parser.parse_args(" ") 
print(args)


Namespace(batch_size=1024, checkpoint_dir=PosixPath('checkpoint'), data=PosixPath(' '), epochs=1000, lambd=0.0051, learning_rate_biases=0.0048, learning_rate_weights=0.2, print_freq=100, projector='8192-8192-8192', weight_decay=1e-06, workers=8)


In [9]:
from barlowtwins_resnet50 import *

# Load previously trained barlowtwins_resnet50 model
# Expects to have base barlow model initialized from checkpoint
bt_resnet50_chkpt = './lightning_logs/barlow/version_1/checkpoints/epoch=58-step=2890.ckpt'
blah = barlowtwins_resnet50.load_from_checkpoint(checkpoint_path=bt_resnet50_chkpt, ds_train=None, args=args)
blah.freeze()
encoder = blah.encoder


In [10]:
print(encoder)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [13]:
model_parameters = filter(lambda p: p.requires_grad, blah.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print('Number of trainable params in model:', params)

Number of trainable params in model: 0
