In [1]:
from random import randint, random

import numpy as np
import math
import time
from collections import deque
import torch as T
import torch.nn.functional as F
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import sklearn.neighbors
import sklearn.metrics
import matplotlib.pyplot as plt
import PIL
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
from nvidia.dali.plugin.pytorch import DALIGenericIterator

from torchvision.models.vision_transformer import VisionTransformer

from imageComponent import *
from remote_read_write import scp_read_wrapper, scp_write_wrapper

# GPU index given to the DALI pipeline as device_id.
dali_device = 1
# PyTorch device string, derived from dali_device so that the pipeline and the
# model cannot silently end up configured for different GPUs.
device = 'cuda:%d' % dali_device
# Ask cuDNN to benchmark its algorithms and keep the fastest one.
T.backends.cudnn.benchmark = True
# Switch for mixed precision (read by autocast and GradScaler in the training cell).
use_amp = True



In [2]:
# Directory prefix for every checkpoint / loss file written through scp_Tsave.
root = '/public/home/ly_1112103017/zyz/imageCL/ckpt/CLCifarViT2'

# Backbone network that MoCo wraps later on.
# ViTCifar is not imported explicitly; presumably it is provided by the
# `from imageComponent import *` line -- TODO confirm.
model = ViTCifar()

In [3]:
# Training image folder; passed as the first argument to DALICLImageFolders.
dataset = '/public/ly/zyz/imageCL/dataset/cifar10-image/train'

In [4]:
# Positional arguments shared by both wrappers. They look like
# (ssh target, port, private-key path, local scratch directory) -- verify
# against remote_read_write before relying on that reading.
# NOTE(review): host, user name and key path are hard-coded here; consider
# moving them to environment variables or a config file.
_scp_args = (
    'ly_1112103017@172.16.35.121',
    30907,
    '/public/ly/zyz/cluster_id_rsa',
    '/public/ly/zyz/imageCL/tmp',
)

# T.load / T.save wrapped by the project's scp read / write helpers.
scp_Tload = scp_read_wrapper(T.load, *_scp_args)
scp_Tsave = scp_write_wrapper(T.save, *_scp_args)

In [5]:
# --- input / augmentation settings (all forwarded to DALICLImageFolders) ---
size = (32, 32)
random_crop_area = [0.2, 1.0]
strength = 1

# --- optimisation ---
batch_size = 700
epoch = 2400        # outer-loop iterations; also T_0 of the cosine-restart scheduler
lr = 1e-5           # multiplied by batch_size / 256 when the optimizer is created

# --- learning-rate schedule ---
warmep = 150        # scheduler steps of warm-up (milestone handed to SequentialLR)
warmlr = 1e-3       # LR multiplier at warm-up step 0; ramps linearly up to 1
eta_min = 1e-2      # floor setting for the cosine-restart phase

# --- contrastive negatives ---
neg_sample = 128000  # divided by batch_size (rounded up) to form MoCo's second argument

In [6]:
# DALI pipeline built from the training folder with the augmentation settings
# defined above; it runs on GPU `dali_device` with 8 worker threads.
pipe = DALICLImageFolders(dataset, size, random_crop_area, strength,
                          batch_size = batch_size, num_threads = 8, device_id = dali_device)

# Use the explicitly imported iterator class and policy enum rather than
# reaching them through a `dali` name that no explicit import in this notebook
# binds. Each batch exposes the keys 'img1', 'img2' and 'label'; DROP discards
# a trailing incomplete batch.
loader = DALIGenericIterator([pipe], reader_name = 'reader', output_map = ['img1', 'img2', 'label'],
                             last_batch_policy = LastBatchPolicy.DROP)

In [7]:
# MoCo wrapper around the backbone. The second argument is neg_sample expressed
# in whole batches (rounded up) -- presumably the negative-queue length; confirm
# against the MoCo definition.
moco = MoCo(model, math.ceil(neg_sample / batch_size))
moco = moco.to(device, memory_format = T.channels_last)

# Base learning rate: lr scaled linearly with batch size relative to 256.
base_lr = lr * batch_size / 256
optim = T.optim.AdamW(moco.parameters(), lr = base_lr)

# Warm-up: the LR multiplier grows linearly from warmlr to 1 over warmep steps.
scheduler1 = T.optim.lr_scheduler.LambdaLR(optim, lambda x:min(1, warmlr + x * (1 - warmlr) / warmep))

# CosineAnnealingWarmRestarts interprets eta_min as an ABSOLUTE learning rate.
# Passing the raw 1e-2 puts the "minimum" far above base_lr (about 2.7e-5), so
# the cosine phase would ramp the LR up instead of decaying it. Treat eta_min
# as a fraction of base_lr, the same way warmlr is a fraction in the warm-up.
scheduler2 = T.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, epoch, T_mult = 2, eta_min = eta_min * base_lr)

# Warm-up for the first warmep scheduler steps, cosine restarts afterwards.
scheduler = T.optim.lr_scheduler.SequentialLR(optim, schedulers = [scheduler1, scheduler2], milestones = [warmep])

In [None]:
# Gradient scaler for the fp16 autocast region; switched on/off by use_amp.
scaler = T.cuda.amp.GradScaler(enabled = use_amp)

# Store the untrained encoder weights as checkpoint 0000.
scp_Tsave(moco.base_encoder.state_dict(), '%s/%04d.pth' %(root, 0))

moco = moco.to(device, memory_format = T.channels_last)

loss_r = []  # summed loss of every finished epoch, in order
for e in range(epoch):
    t = time.time()
    l = 0
    for data in loader:
        views = data[0]  # outputs of the first (only) pipeline given to the iterator
        with T.no_grad():
            img1 = views['img1'].to(device, memory_format = T.channels_last)
            img2 = views['img2'].to(device, memory_format = T.channels_last)

        # Forward pass under autocast; moco returns the loss directly.
        with T.autocast(device_type = 'cuda', dtype = T.float16, enabled = use_amp):
            loss = moco(img1, img2)

        # Scaled backward pass, optimizer step, scaler update, then clear grads.
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad(set_to_none = True)

        l += loss.item()

    # End of epoch: checkpoint the encoder, report, persist the loss history.
    scp_Tsave(moco.base_encoder.state_dict(), '%s/%04d.pth' %(root, e + 1))
    print('epoch:%d\tloss:%f' %(e, l))
    print(time.time() - t, end = '\n\n')
    loss_r.append(l)
    scp_Tsave(loss_r, '%s/loss.pth' %root)

    # The LR schedule advances once per epoch.
    scheduler.step()

epoch:0	loss:560.794092
60.24011850357056

epoch:1	loss:636.715875
61.20248746871948

epoch:2	loss:662.802594
58.966673851013184

epoch:3	loss:664.968444
58.80326199531555

epoch:4	loss:663.240236
57.9626624584198

epoch:5	loss:661.091111
57.85875964164734

epoch:6	loss:658.919560
57.66527032852173

epoch:7	loss:657.070302
57.59921836853027

epoch:8	loss:654.873638
57.15377402305603

epoch:9	loss:653.053690
57.62680006027222

epoch:10	loss:652.104828
57.43871307373047

