In [None]:
# Notebook environment setup: extensions, version report and display options.
# Load `watermark` (version report) and `autoreload` (re-import edited modules before each cell).
%reload_ext watermark
%reload_ext autoreload
%autoreload 2
# Print the Python version and the versions of the packages used below, for reproducibility.
%watermark -v -p numpy,sklearn,pandas
%watermark -v -p cv2,PIL,matplotlib
%watermark -v -p torch,torchvision,torchaudio,pytorch_lightning
# Render matplotlib figures inline at high (retina) resolution.
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Disable jedi-based tab completion.
%config IPCompleter.use_jedi = False


from IPython.display import display, HTML, Javascript
# Widen the notebook container to 90% of the browser window.
display(HTML('<style>.container { width:%d%% !important; }</style>' % 90))

def _IMPORT_(x):
    try:
        exec(x, globals())
    except:
        pass

## Import Modules

In [None]:
###
### Common ###
###

import os, io, time, random, math, base64

_IMPORT_('import numpy as np')
_IMPORT_('import pandas as pd')
_IMPORT_('from tqdm.notebook import tqdm')


###
### Torch ###
###

_IMPORT_('import torch')
_IMPORT_('import torch.nn as nn')
_IMPORT_('import torch.nn.functional as F')
_IMPORT_('import torch.optim as O')
_IMPORT_('from torchvision import models as M')
_IMPORT_('from torchvision import transforms as T')
_IMPORT_('from torch.utils.data import Dataset, DataLoader')


###
### Viz Model ###
###

_IMPORT_('import wandb')
_IMPORT_('import hiddenlayer as hl')
_IMPORT_('from graphviz import Digraph, Source')
_IMPORT_('from torchviz import make_dot')
_IMPORT_('from torchsummary import summary')


###
### Display ###
###

_IMPORT_('import cv2')
_IMPORT_('from PIL import Image')
_IMPORT_('from torchvision.utils import make_grid')
_IMPORT_('import matplotlib.pyplot as plt')
_IMPORT_('import plotly')
_IMPORT_('import plotly.graph_objects as go')

# Initialize plotly's notebook mode; connected=False embeds plotly.js in the notebook
# instead of loading it from a CDN.
plotly.offline.init_notebook_mode(connected=False)

def show_video(video_path, width=None, height=None):
    """Embed an MP4 file in the notebook as an inline HTML5 ``<video>`` element.

    The whole file is base64-encoded into a data URL, so the output is
    self-contained (and grows with the clip size).

    Args:
        video_path: path to an ``.mp4`` file.
        width: optional display width in pixels.
        height: optional display height in pixels.

    Returns:
        An ``IPython.display.HTML`` object; make it the last expression of a cell
        to render the player.
    """
    W = 'width=%d' % width if width else ''
    H = 'height=%d' % height if height else ''
    # Close the file promptly instead of leaking the handle until garbage collection.
    with open(video_path, 'rb') as fp:
        mp4 = fp.read()
    data_url = 'data:video/mp4;base64,' + base64.b64encode(mp4).decode()
    return HTML('<video %s %s controls src="%s" type="video/mp4"/>' % (W, H, data_url))

def show_image(image_path, width=None, height=None):
    """Embed an image file in the notebook as an inline ``<img>`` element.

    Args:
        image_path: path to the image (PNG, JPEG, ...).
        width: optional display width in pixels.
        height: optional display height in pixels.

    Returns:
        An ``IPython.display.HTML`` object rendering the image from a base64 data URL.
    """
    import mimetypes  # local import: only this helper needs it

    W = 'width=%d' % width if width else ''
    H = 'height=%d' % height if height else ''
    # Use the file's real MIME type (e.g. image/png for the .png diagram); the old
    # hard-coded 'image/jpg' is not a registered type. Fall back to JPEG if unknown.
    mime = mimetypes.guess_type(image_path)[0] or 'image/jpeg'
    with open(image_path, 'rb') as fp:
        img = fp.read()
    data_url = 'data:%s;base64,' % mime + base64.b64encode(img).decode()
    return HTML('<img %s %s src="%s"/>' % (W, H, data_url))

###
### Random Seed ###
###

def set_rng_seed(x):
    """Seed Python's, NumPy's and PyTorch's global random number generators with ``x``.

    Each library is seeded independently. If ``_IMPORT_`` could not import numpy
    or torch, the missing name raises NameError; that library is skipped and the
    others are still seeded. Previously one bare ``except`` around all three calls
    skipped every seeder after the first failure.
    """
    seeders = (
        lambda: random.seed(x),
        lambda: np.random.seed(x),
        lambda: torch.manual_seed(x),
    )
    for seed_fn in seeders:
        try:
            seed_fn()
        except NameError:
            # The library was not importable; nothing to seed.
            pass

set_rng_seed(888)


In [None]:
# Frame size (pixels) passed to ffmpeg `-s` when cropping clips.
FRAME_WIDTH = 112
FRAME_HEIGHT = 112
# Frames per clip: preprocessing resamples each clip to about this many frames.
NUM_FRAMES = 64
# Width of the projected tokens / transformer model.
NUM_DMODEL = 512
# Context kept before/after the repetition segment, as a fraction of its duration.
REP_OUT_TIME_RATE = 0.12
# Root directory holding the Countix CSVs and the raw / cropped videos.
DATASET_PREFIX = '/data/datasets/cv/countix'

## Install Dependencies

In [None]:
# Install the tools needed for dataset preparation.
# `-y` answers apt's confirmation prompt; a notebook cell cannot type it, so the cell would hang.
!apt install -y ffmpeg
# `%pip` installs into the running kernel's environment (not whichever `pip3` is on PATH).
%pip install youtube_dl

## Data Process

### Countix Dataset Download and Crop

```
    vs           cs                ce            ve
    |             |0.5           0.5|             |
    |-------|---------|---------|------|----------|
            |         |         |      |            
            ks       rs        re      ke
    
vs: the video start
ve: the video end
ks: the kinetics start
ke: the kinetics end
rs: repetition start
re: repetition end
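cs: clip video start, cs = max(ks, rs - REP_OUT_TIME_RATE * (re - rs))
ce: clip video end,   ce = min(ke, re + REP_OUT_TIME_RATE * (re - rs))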

```

In [None]:
import youtube_dl

YOUTUBE_PREFIX = 'https://www.youtube.com/watch?v='

# SOCKS5 proxy used for YouTube downloads. Set the SOCKS5_PROXY environment
# variable to override it instead of editing the notebook.
SOCKS5_PROXY = os.environ.get('SOCKS5_PROXY', 'socks5://127.0.0.1:1881')

# Base youtube_dl options. Per-download settings (e.g. 'outtmpl') belong in a copy,
# so this shared dict is not changed between downloads.
YDL_OPTS = {
    'format': 'mp4',
    'proxy': SOCKS5_PROXY,
    'quiet': True,
    'max_filesize': 30000000, # 30MB
}

def video_download_crop(vid, fps, wh, ss, to, raw_dir, out_dir, force=False):
    """Download a YouTube video (if needed) and cut/resample its [ss, to] segment with ffmpeg.

    Args:
        vid: YouTube video id.
        fps: output frame rate passed to ffmpeg ``-r``.
        wh: output size string for ffmpeg ``-s``, e.g. ``'112x112'``.
        ss: segment start time in seconds (ffmpeg ``-ss``).
        to: segment end time in seconds (ffmpeg ``-to``).
        raw_dir: directory for the full download, saved as ``{vid}.mp4``.
        out_dir: directory for the cropped clip.
        force: if True, regenerate the clip even when it already exists.

    Returns:
        Path of the cropped clip, or None if the download or the crop failed.
    """
    import subprocess  # not among the notebook's top-level imports

    raw_file = f'{raw_dir}/{vid}.mp4'
    out_file = '%s/%s_%010.6f_%010.6f.mp4' % (out_dir, vid, ss, to)

    if os.path.exists(out_file):
        if not force:
            return out_file
        # Drop the stale clip and fall through to regenerate it (the deleted path
        # used to be returned).
        os.remove(out_file)

    if not os.path.exists(raw_file):
        # Work on a copy so the shared module-level YDL_OPTS is not mutated.
        opts = dict(YDL_OPTS, outtmpl=raw_file)
        with youtube_dl.YoutubeDL(opts) as ydl:
            ydl.download([f'{YOUTUBE_PREFIX}{vid}'])

    if not os.path.exists(raw_file):
        return None

    # Argument list instead of a shell string: ids and paths are never parsed by a shell.
    cmd = ['ffmpeg', '-i', raw_file, '-v', '0', '-r', '%f' % fps, '-s', wh,
           '-ss', str(ss), '-to', str(to), '-an', out_file]
    subprocess.call(cmd)
    # Only report success if ffmpeg actually produced the clip.
    return out_file if os.path.exists(out_file) else None

def data_preprocess(data_prefix, phase, force=False):
    """Download, crop and validate every clip of one Countix split.

    Reads ``{data_prefix}/countix_{phase}.csv``. For each row it crops the
    repetition segment, extended by REP_OUT_TIME_RATE of its duration on both sides
    and clamped to the kinetics segment. The crop is resampled to about NUM_FRAMES
    frames. Usable rows are written to ``{data_prefix}/sub_countix_{phase}.csv``.

    Args:
        data_prefix: dataset root directory.
        phase: split name ('train', 'val', 'test' or 'sample').
        force: forwarded to video_download_crop to regenerate existing clips.

    Returns:
        The filtered DataFrame. It has the added columns ``file_name``,
        ``rep_start_frame`` and ``rep_end_frame``; the last two are the frame
        indices of the repetition inside the cropped clip.
    """
    df = pd.read_csv(f'{data_prefix}/countix_{phase}.csv')
    raw_dir = f'{data_prefix}/raw/{phase}'
    out_dir = f'{data_prefix}/{phase}'
    os.makedirs(raw_dir, exist_ok=True)
    os.makedirs(out_dir, exist_ok=True)
    df['file_name'] = None
    df['rep_start_frame'] = 0
    df['rep_end_frame'] = 0
    for idx, row in df.iterrows():
        # Columns are read by position; train/val CSVs have one extra column at index 1.
        # (`rep_end` rather than `re`, so the `re` module name is not shadowed.)
        if phase == 'test' or phase == 'sample':
            vid, ks, ke, rs, rep_end, count = row.iloc[:6]
        else:
            vid = row.iloc[0]
            ks, ke, rs, rep_end, count = row.iloc[2:7]

        interval = rep_end - rs
        cs = float(max([ks, rs - REP_OUT_TIME_RATE * interval]))
        ce = float(min([ke, rep_end + REP_OUT_TIME_RATE * interval]))
        try:
            fps = NUM_FRAMES / (ce - cs)
            out_file = video_download_crop(vid,
                    fps, '%dx%d' % (FRAME_WIDTH, FRAME_HEIGHT), cs, ce, raw_dir, out_dir, force)
            if out_file is None:
                print('download or crop [%s] fail' % vid)
                continue

            # The capture is only used to probe the frame count; always release it.
            cap = cv2.VideoCapture(out_file)
            try:
                cnt = cap.get(cv2.CAP_PROP_FRAME_COUNT)
            finally:
                cap.release()

            if cnt < NUM_FRAMES:
                print(f'frames is less than {NUM_FRAMES}')
                continue

            print('preprocess file: %s: %d, count: %d' % (out_file, cnt, count))
            rsf = int(fps * (rs - cs))
            ref = int(fps * (rep_end - cs))
            # Keep clips with more than one repetition and fewer repetitions than half the
            # repetition's frame span.
            if 1 < count < (ref - rsf) // 2:
                df.loc[idx, 'rep_start_frame'] = rsf
                df.loc[idx, 'rep_end_frame'] = ref
                df.loc[idx, 'file_name'] = os.path.basename(out_file)
            else:
                print('[%s] count:[%d] %d %d' % (vid, count, ref, rsf))
        except Exception as err:
            print('%s' % err)
    sub_df = df[df['file_name'].notnull()]
    sub_df.to_csv(f'{data_prefix}/sub_countix_{phase}.csv', index=False, header=True)
    return sub_df

# data_preprocess(DATASET_PREFIX, 'test')
# data_preprocess(DATASET_PREFIX, 'val')
# data_preprocess(DATASET_PREFIX, 'train')

In [None]:
# List the dataset root to see which splits / CSVs have been produced.
!ls $DATASET_PREFIX

### Countix Dataset Sample Display

In [None]:
# Load the filtered training split (the sub_countix_*.csv written by data_preprocess).
df_train = pd.read_csv(f'{DATASET_PREFIX}/sub_countix_train.csv')

In [None]:
# Sanity check: distribution of the period length (frames per repetition).
def calculate_period_length(row):
    """Return ``row`` with ``period_length`` = repetition frame span / repetition count."""
    row.period_length = (row.rep_end_frame - row.rep_start_frame) / row['count']
    return row

# Vectorized equivalent of applying calculate_period_length row by row. A single column
# operation replaces a Python-level loop that also rebuilt the whole frame.
df_train['period_length'] = (
    (df_train['rep_end_frame'] - df_train['rep_start_frame']) / df_train['count'])
df_train.describe()

In [None]:
# Peek at the first rows of the training split.
df_train.head()

In [None]:
# Inspect one training clip: print its CSV row, decode every frame with OpenCV
# and embed the video inline.
sample_video_item = df_train.iloc[3]
print(sample_video_item)
sample_video_path = f'{DATASET_PREFIX}/train/{sample_video_item.file_name}'

cap = cv2.VideoCapture(sample_video_path)
# Container metadata: frame count, frame size and frame rate.
count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)

# fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
# vwrite = cv2.VideoWriter('/tmp/t3.mp4', fourcc=fourcc, fps=fps, frameSize=(width, height))
print(count, fps)
# Decode until the reader reports end of stream; len(frames) is printed below for
# comparison with the reported frame count.
frames = []
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    frames.append(frame)
    # vwrite.write(frame)
cap.release()
# vwrite.release()
print(len(frames))
show_video(sample_video_path, width=600)

### Countix Dataset Loader 

In [None]:
class CountixDataset(Dataset):
    """Countix clips with per-frame period-length and periodicity labels.

    Rows come from ``{data_root}/sub_countix_{phase}.csv``; clips from
    ``{data_root}/{phase}/{file_name}``.

    Each item is ``(X, y1, y2, y3)``:
        X:  (num_frames, 3, H, W) float tensor of RGB frames, resized to
            ``frame_size`` and normalized with ImageNet mean/std.
        y1: (num_frames, 1) period length in frames for frames inside
            [rep_start_frame, rep_end_frame], 0 elsewhere.
        y2: (num_frames, 1) periodicity: 1 inside that span, 0 elsewhere.
        y3: (1,) repetition count.
    """

    def __init__(self, data_root, phase, frame_size=112, num_frames=64):
        self.data_root = data_root
        self.phase = phase
        self.num_frames = num_frames
        self.frame_size = (frame_size, frame_size) if isinstance(frame_size, int) else frame_size
        self.df = pd.read_csv(f'{data_root}/sub_countix_{phase}.csv')
        # Built once here instead of once per decoded frame.
        self.transform = T.Compose([
            T.Resize(self.frame_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])])

    def _read_frames(self, path):
        """Decode at most ``num_frames`` frames of ``path`` as (1, 3, H, W) tensors."""
        frames = []
        cap = cv2.VideoCapture(path)
        try:
            while cap.isOpened() and len(frames) < self.num_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR, but the ImageNet mean/std above (and the
                # pretrained ResNet) expect RGB channel order.
                # NOTE: checkpoints trained on the old BGR input need retraining.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(Image.fromarray(frame)).unsqueeze(0))
        finally:
            cap.release()
        return frames

    def __getitem__(self, index):
        item = self.df.iloc[index]
        start = item.rep_start_frame
        end = item.rep_end_frame
        # Read from the same positional row as the other fields (the old label-based
        # .loc lookup was only correct for a default 0..n-1 index).
        count = item['count']

        path = f'{self.data_root}/{self.phase}/{item.file_name}'
        period_length = int((end - start) / count)

        frames = self._read_frames(path)
        if not frames:
            raise RuntimeError(f'no frames could be decoded from {path}')
        # Pad short clips by repeating the last frame so every sample has num_frames frames.
        frames += [frames[-1]] * (self.num_frames - len(frames))
        X = torch.cat(frames)

        # Frame i belongs to the repetition iff start <= i <= end.
        idx = np.arange(self.num_frames)
        in_rep = ((idx >= start) & (idx <= end)).reshape(-1, 1)
        y1 = torch.FloatTensor(np.where(in_rep, period_length, 0))  # period length / per frame [2, 3, ..., N/2 ]
        y2 = torch.FloatTensor(in_rep.astype(np.float64))  # periodicity [0, 1]
        y3 = torch.FloatTensor([count])
        return X, y1, y2, y3

    def __len__(self):
        return len(self.df)

In [None]:
# Small unshuffled loader used only to fetch one batch for the shape checks below.
train_dataset = CountixDataset(DATASET_PREFIX, 'train')
train_loader = DataLoader(train_dataset, batch_size=4, num_workers=1)

In [None]:
# Fetch one batch and show the tensor shapes and the repetition counts.
X, y1, y2, y3 = next(iter(train_loader))
X.shape, y1.shape, y2.shape, y3

## RepNet Model

In [None]:
# RepNet architecture diagram (expects repnet_model.png in the working directory).
show_image('repnet_model.png', width=1000)

### Encoder

#### Convolutional Feature Extractor

In [None]:
class ResNet50Base5D(nn.Module):
    """Truncated ResNet-50 applied frame-wise to a 5-D clip tensor.

    ``forward`` folds (N, S, C, H, W) into a batch of N*S images, runs the 2-D
    trunk and unfolds the result back to (N, S, C', H', W').

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``.
        m: 1 = keep the torchvision module and replace everything after
           ``layer3[2]`` with ``nn.Identity``. Its output comes back flattened, so
           the (1024, 7, 7) feature shape is hard-coded. Any other value rebuilds
           the same truncated trunk as an ``nn.Sequential``.
    """

    def __init__(self, pretrained=False, m=2):
        super().__init__()
        resnet = M.resnet50(pretrained=pretrained)
        self.m = m

        if m != 1:
            # method-2: every child before layer3, then the first three blocks of layer3.
            stages = list(resnet.children())
            self.base_model = nn.Sequential(*stages[:-4], *stages[-4][:3])
            return

        # method-1: neutralize the head, layer4 and layer3 blocks 3-5 in place.
        for attr in ('fc', 'avgpool', 'layer4'):
            setattr(resnet, attr, nn.Identity())
        for block_idx in (3, 4, 5):
            resnet.layer3[block_idx] = nn.Identity()
        self.base_model = resnet

    def forward(self, x):
        n, s, c, h, w = x.shape
        feats = self.base_model(x.reshape(-1, c, h, w))  # 5D -> 4D
        if self.m == 1:
            return feats.reshape(n, s, 1024, 7, 7)
        return feats.reshape(n, s, *feats.shape[1:])  # 4D -> 5D

In [None]:
# Backbone alone with randomly initialized weights, for a quick shape check.
resnet50 = ResNet50Base5D(pretrained=False)
resnet50

In [None]:
# Frame-wise backbone features for the sample batch, in (N, S, C, H, W) layout.
resnet50_outputs = resnet50(X)
resnet50_outputs.shape

In [None]:
# Visualize feature channel 0 of every frame of the first clip as one image grid.
g = resnet50_outputs[0][:, :1]
g = make_grid(g, padding=3)
# make_grid returns (C, H, W); matplotlib expects (H, W, C).
g = g.detach().numpy().transpose((1, 2, 0))
plt.axis('off')
plt.imshow(g, plt.get_cmap('gray'))

#### Temporal Context

In [None]:
class TemporalContext(nn.Module):
    """Dilated 3-D convolution (+ BatchNorm + ReLU) that mixes neighbouring frames.

    Input and output use the (N, S, C, H, W) layout. The size-3 kernel has temporal
    dilation 3 and temporal padding 3, so S, H and W are preserved. Only the
    channel count changes (in_channels -> out_channels).
    """

    def __init__(self, in_channels=1024, out_channels=512):
        super().__init__()
        conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=(3, 1, 1),
            dilation=(3, 1, 1))
        self.conv3D = nn.Sequential(conv, nn.BatchNorm3d(out_channels), nn.ReLU())

    def forward(self, x):
        # Conv3d wants channels in dim 1: swap (N, S, C, ...) <-> (N, C, S, ...) around it.
        return self.conv3D(x.transpose(1, 2)).transpose(1, 2)

In [None]:
# Temporal context on the backbone features (1024 -> 512 channels by default).
TC = TemporalContext()
tc_outputs = TC(resnet50_outputs)
tc_outputs.shape

#### Dimensionality Reduction

In [None]:
class GlobalMaxPool(nn.Module):
    """Spatial global max-pooling: (N, S, C, H, W) -> (N, S, C).

    Args:
        m: 1 = ``amax`` over the H and W dimensions; 2 = adaptive 3-D max-pool
           that collapses (H, W) to 1x1. Both give the same values for any H, W.
           Method 2 previously used a fixed 7x7 kernel and only worked on 7x7 maps.
    """

    def __init__(self, m=1):
        super().__init__()
        self.m = m

        # method:2 — output size (None, 1, 1): keep dim 2 (C), reduce H and W to 1.
        if m == 2:
            self.pool = nn.AdaptiveMaxPool3d((None, 1, 1))

    def forward(self, x):
        # Inputs: (N, S, C, H, W), e.g. the TemporalContext output.
        if self.m == 1:
            # method:1
            return x.amax(dim=(3, 4))
        # method:2 — (N, S, C, 1, 1) -> (N, S, C)
        return self.pool(x).flatten(2)

In [None]:
# Collapse the spatial dimensions of the temporal-context features.
GMP = GlobalMaxPool()
gmp_outputs = GMP(tc_outputs)
gmp_outputs.shape

#### Together

In [None]:
class Encoder(nn.Module):
    """Per-frame embedding network: pretrained ResNet-50 trunk -> temporal context -> spatial max-pool.

    Maps a (N, S, C, H, W) clip to (N, S, 512) frame embeddings.
    """

    def __init__(self):
        super().__init__()
        self.cnn = ResNet50Base5D(pretrained=True)
        self.temporal_context = TemporalContext(in_channels=1024, out_channels=512)
        self.max_pool = GlobalMaxPool()

    def forward(self, x):
        for stage in (self.cnn, self.temporal_context, self.max_pool):
            x = stage(x)
        return x

# End-to-end encoder on the sample batch (the ResNet trunk uses torchvision's pretrained weights).
encoder = Encoder()
encoder_outputs = encoder(X)
encoder_outputs.shape

### Temporal Self-Similarity Matrix (TSM)

In [None]:
class TemproalSelfMatrix(nn.Module):
    """Temporal self-similarity matrix (TSM) of per-frame embeddings.

    For x of shape (N, S, E), computes pairwise squared L2 distances between
    frames, applies a row-wise softmax of ``-distance / temperature`` and returns
    an (N, 1, S, S) tensor whose rows sum to 1.

    Args:
        num_frames: length of the ``one_value`` buffer. The buffer is kept only so
            existing state_dicts still load; the distances no longer depend on it.
        temperature: softmax temperature.
        m: 1 = per-sample expanded-norm distance; otherwise batched broadcast difference.
    """

    def __init__(self, num_frames=64, temperature=13.544, m=1):
        super().__init__()
        self.m = m
        self.temperature = temperature
        # Both buffers are state_dict entries; keep them for checkpoint compatibility.
        self.register_buffer('zero_value', torch.tensor(0.0))
        self.register_buffer('one_value', torch.ones(num_frames))

    def calc_sims(self, x):
        """Pairwise squared distances by broadcasting: (N, S, E) -> (N, S, S).

        Works for any S. The previous einsum against ``one_value`` required
        S == num_frames.
        """
        diff = x.unsqueeze(1) - x.unsqueeze(2)  # (N, S, S, E): diff[n, i, j] = x[n, j] - x[n, i]
        return (diff * diff).sum(dim=-1)

    def pairwise_l2_distance(self, x):
        """Squared L2 distance matrix of one sample, (S, E) -> (S, S), clamped at 0.

        Uses ||a||^2 - 2 a.b + ||b||^2. The clamp removes the small negative values
        produced by floating-point cancellation.
        """
        a, b = x, x
        norm_a = torch.sum(torch.square(a), dim=1)
        norm_a = torch.reshape(norm_a, [-1, 1])
        norm_b = torch.sum(torch.square(b), dim=1)
        norm_b = torch.reshape(norm_b, [1, -1])
        b = torch.transpose(b, 0, 1)  # a: (S, E)  b: (E, S)
        dist = torch.maximum(
            norm_a - 2.0 * torch.matmul(a, b) + norm_b,
            self.zero_value)
        return dist

    def forward(self, x):
        # x: (N, S, E)
        if self.m == 1:
            # method: 1
            sims = torch.stack([self.pairwise_l2_distance(sample) for sample in x])
        else:
            # method: 2
            sims = self.calc_sims(x)

        # Softmax output is already non-negative; the former trailing F.relu was a no-op.
        sims = F.softmax(-sims / self.temperature, dim=-1)
        return sims.unsqueeze(1)  # (N, 1, S, S)

In [None]:
# Self-similarity matrices of the pooled embeddings, using the batched method (m=2).
TSM = TemproalSelfMatrix(m=2)
tsm_outputs = TSM(gmp_outputs)
tsm_outputs.shape

In [None]:
def show_hotmap(data):
    """Plot a 2-D array as a heat map with 1-based tick labels on both axes.

    Args:
        data: 2-D array, e.g. one (S, S) self-similarity matrix.
    """
    fig, ax = plt.subplots(figsize=(12, 12))

    # Columns label the x axis, rows the y axis. Both previously used shape[0],
    # which mislabelled non-square inputs.
    n_rows, n_cols = data.shape[0], data.shape[1]
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(range(1, n_cols + 1), rotation=90)
    ax.xaxis.set_ticks_position('top')

    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(range(1, n_rows + 1))

    im = ax.imshow(data, cmap=plt.cm.PuBu)
    fig.colorbar(im, ax=ax)
# Heat map of the first sample's (S, S) self-similarity matrix.
tsm_hotmap_data = tsm_outputs[0][0].detach().numpy()
show_hotmap(tsm_hotmap_data)

### Period Predictor

#### Features Projection

In [None]:
# Stand-alone copy of RepNet's `tsm_features` stage: a 3x3 conv (padding 1, so S x S
# is preserved) lifts the 1-channel TSM to 32 feature maps.
TSM_Features = nn.Sequential(
    nn.Conv2d(in_channels=1,
              out_channels=32,
              kernel_size=3,
              padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Dropout(p=0.25))

tsm_f_outputs = TSM_Features(tsm_outputs)
tsm_f_outputs.shape

In [None]:
class FeaturesProjection(nn.Module):
    """Turn TSM conv features into one token per frame.

    Input (N, C, S, S), where C = ``in_channels`` feature maps. For every row i,
    the C*S values of row i across all maps are concatenated and passed through
    Linear -> ReLU -> LayerNorm. Output: (N, S, out_features).

    Args:
        num_frames: S.
        out_features: token width.
        in_channels: number of feature maps C. It defaults to the previously
            hard-coded 32, which matches the TSM conv output.
    """

    def __init__(self, num_frames=64, out_features=512, in_channels=32):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(num_frames * in_channels, out_features),
            nn.ReLU(),
            nn.LayerNorm(out_features))

    def forward(self, x):
        # [N, C, S, S] -> [N, S, S, C]
        x = x.permute(0, 2, 3, 1)
        x = x.reshape(x.size(0), x.size(1), -1)  # N, S, S*C
        x = self.projection(x)  # N, S, out_features
        return x

In [None]:
# Project the TSM feature maps to one 512-wide token per frame.
TSM_FP = FeaturesProjection()
tsm_fp_outputs = TSM_FP(tsm_f_outputs)
tsm_fp_outputs.shape

#### Transformer Encoder

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first input, followed by dropout.

    Args:
        d_model: embedding width E (odd widths are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest supported sequence length.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (S, N, E); the first S encodings broadcast over the batch dimension.
        x = x + self.pe[:x.size(0), :, :]
        return self.dropout(x)
    
class TransformerModel(nn.Module):
    """Transformer encoder over per-frame tokens: (N, S, E) -> (N, S, E).

    Args:
        num_frames: sequence length S (size of the positional encoding).
        d_model, n_head, dim_ff, dropout, num_layers: nn.TransformerEncoder settings.
        m: 1 = fixed sinusoidal PositionalEncoding; otherwise a learned
           (1, S, 1) positional offset.
    """

    def __init__(self, num_frames=64, d_model=512,
                 n_head=4, dim_ff=512, dropout=0.2,
                 num_layers=2, m=2):
        super().__init__()
        self.m = m
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=dim_ff,
            dropout=dropout,
            activation='relu')
        encoder_norm = nn.LayerNorm(d_model)
        if m == 1:
            self.pos_encoder = PositionalEncoding(d_model, dropout, num_frames)
        else:
            # Register as nn.Parameter so the learned encoding is actually optimized,
            # saved in the state_dict and moved by .to(device). The old plain tensor
            # with requires_grad=True was none of these.
            self.pos_encoder = nn.Parameter(
                torch.empty(1, num_frames, 1).normal_(mean=0, std=0.02))

        self.trans_encoder = nn.TransformerEncoder(encoder_layer, num_layers, encoder_norm)

    def forward(self, x):
        # x: [N, S, E]. The encoder layers use the default sequence-first layout, so
        # both branches run them on [S, N, E]. The m=2 branch used to pass [N, S, E]
        # directly, which made attention mix samples of the batch instead of frames.
        if self.m == 1:
            x = x.transpose(0, 1)
            x = self.pos_encoder(x)  # S, N, D_Model
        else:
            # Out-of-place add: the caller's tensor is left untouched.
            x = (x + self.pos_encoder).transpose(0, 1)  # S, N, D_Model
        x = self.trans_encoder(x)
        return x.transpose(0, 1)

In [None]:
# Transformer over the projected tokens, with the sinusoidal positional encoding (m=1).
TE = TransformerModel(NUM_FRAMES, d_model=NUM_DMODEL, n_head=4, dropout=0.2, dim_ff=512, m=1)
te_outputs = TE(tsm_fp_outputs)
te_outputs.shape

#### Period Classifier

In [None]:
class PeriodClassifier(nn.Module):
    """Per-frame MLP head: (..., in_features) -> (..., out_features), ending in a ReLU.

    Layers: Dropout -> Linear(in, 512) -> ReLU -> Dropout -> Linear(512, num_frames // 2)
    -> ReLU -> Linear(num_frames // 2, out) -> ReLU.
    """

    def __init__(self, num_frames=64, in_features=512, out_features=1):
        super().__init__()
        hidden = num_frames // 2
        layers = [
            nn.Dropout(p=0.25),
            nn.Linear(in_features=in_features, out_features=512),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Linear(in_features=512, out_features=hidden),
            nn.ReLU(),
            nn.Linear(in_features=hidden, out_features=out_features),
            nn.ReLU(),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        return self.classifier(x)

In [None]:
# Classifier head on the transformer outputs. Only the last expression (`pc`, the
# module structure) is displayed; `pc_outputs.shape` on the line before is not shown.
pc = PeriodClassifier(NUM_FRAMES)
pc_outputs = pc(te_outputs)
pc_outputs.shape
pc

In [None]:
# List the parameter names of the classifier head.
for name, _ in pc.named_parameters():
    print(name)

### Make All Together

In [None]:
class RepNet(nn.Module):
    """RepNet: repetition counting from a temporal self-similarity matrix.

    Pipeline: ResNet-50 trunk -> temporal context -> spatial max-pool -> TSM ->
    shared 3x3 conv -> two heads (projection + transformer + per-frame MLP).
    The heads return:
      * y1: per-frame period length (non-negative, the head ends in ReLU);
      * y2: per-frame periodicity logits. The head has no final ReLU, so logits
        can be negative, as BCEWithLogitsLoss and the ``y2_pred > 0`` threshold
        require.

    Args:
        num_frames: clip length S.
        num_dmodel: token / transformer width.
    """

    def __init__(self, num_frames=64, num_dmodel=512):
        super().__init__()
        # Encoder
        self.resnet50 = ResNet50Base5D(pretrained=True)
        self.tcxt = TemporalContext()
        self.maxpool = GlobalMaxPool(m=1)
        # TSM
        self.tsm = TemproalSelfMatrix(num_frames=num_frames, temperature=13.544, m=1)  # noqa

        # Period Predictor
        self.tsm_features = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=32,
                      kernel_size=3,
                      padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(p=0.25))

        self.projection1 = FeaturesProjection(num_frames=num_frames, out_features=num_dmodel)
        self.projection2 = FeaturesProjection(num_frames=num_frames, out_features=num_dmodel)

        # period length prediction
        self.trans1 = TransformerModel(
                num_frames, d_model=num_dmodel, n_head=4,
                dropout=0.25, dim_ff=num_dmodel, m=1)

        self.pc1 = PeriodClassifier(num_frames, num_dmodel)
        # periodicity prediction
        self.trans2 = TransformerModel(
                num_frames, d_model=num_dmodel, n_head=4,
                dropout=0.25, dim_ff=num_dmodel, m=1)
        self.pc2 = PeriodClassifier(num_frames, num_dmodel)
        # The periodicity head must emit logits. PeriodClassifier ends with ReLU,
        # which clamps every logit to >= 0 (probability >= 0.5) and zeroes the gradient
        # below 0. ReLU has no parameters, so state_dict keys are unchanged.
        self.pc2.classifier[-1] = nn.Identity()

    def forward(self, x, retsim=False):
        """Return ``(y1, y2)``, or ``(y1, y2, tsm)`` when ``retsim`` is True."""
        x = self.resnet50(x)  # [N, 64, 1024, 7, 7]
        x = self.tcxt(x)  # [N, 64, 512, 7, 7]
        x = self.maxpool(x)  # [N, 64, 512]
        sims = self.tsm(x)  # [N, 1, 64, 64]

        x = self.tsm_features(sims)  # [N, 32, 64, 64]

        y1 = self.pc1(self.trans1(self.projection1(x)))  # L: period length
        y2 = self.pc2(self.trans2(self.projection2(x)))  # P: periodicity logits
        if retsim:
            return y1, y2, sims
        return y1, y2

In [None]:
# Full model forward pass on the sample batch: per-frame period-length (y1) and periodicity (y2) outputs.
repnet = RepNet()
y1, y2 = repnet(X)
y1.shape, y2.shape

In [None]:
# List every parameter name of the full model (`weight` is unused).
for name, weight in repnet.named_parameters():
    print(name)

## Train

In [None]:
# Loaders for the three splits. Only training is shuffled (and drops the last
# incomplete batch); validation and test keep CSV order.
train_dataset = CountixDataset(DATASET_PREFIX, 'train')
train_loader = DataLoader(train_dataset, batch_size=8, num_workers=4, shuffle=True, drop_last=True)

valid_dataset = CountixDataset(DATASET_PREFIX, 'val')
valid_loader = DataLoader(valid_dataset, batch_size=8, num_workers=4, shuffle=False)

test_dataset = CountixDataset(DATASET_PREFIX, 'test')
# Built from test_dataset (it previously wrapped valid_dataset, so "test" inference
# silently re-ran the validation clips).
test_loader = DataLoader(test_dataset, batch_size=8, num_workers=1, shuffle=False)

In [None]:
ckpt_path = f'{DATASET_PREFIX}/repnet5.pt'
# Fall back to CPU when no GPU is present instead of failing at .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RepNet(NUM_FRAMES, NUM_DMODEL).to(device)

optimizer = O.Adam(model.parameters(), lr=0.0001)
# Alternative optimizers / schedules that were tried:
# optimizer = O.SGD(model.parameters(), lr=lr)
# scheduler = O.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)
# scheduler = O.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.9)
# scheduler = O.lr_scheduler.MultiStepLR(optimizer, milestones=[
#         3, 10, 50, 100, 200, 300, 400], gamma=0.6)
# Divide the LR by 10 after 5 scheduler steps (epochs) without validation-loss improvement, floor 1e-6.
scheduler = O.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-6)
# [period-length regression loss, periodicity loss on logits]
criterions = [nn.SmoothL1Loss(), nn.BCEWithLogitsLoss()]

In [None]:
def _compute_losses(model, X, y1, y2, criterions):
    """Forward pass plus the RepNet losses shared by training and validation.

    Returns ``(loss, loss1, loss2, loss3)``:
      * loss1: period-length loss ``criterions[0](y1_pred, y1)``;
      * loss2: periodicity loss ``criterions[1](y2_pred, y2)``;
      * loss3: ``criterions[0]`` between the count implied by the predictions and
        the count implied by the labels;
      * loss = 3*loss1 + 5*loss2 + 2*loss3.
    """
    y1_pred, y2_pred = model(X)

    loss1 = criterions[0](y1_pred, y1)
    loss2 = criterions[1](y2_pred, y2)

    # count error: sum over frames of periodic / period_length (+0.1 avoids dividing by 0)
    y3_pred = torch.sum((y2_pred > 0) / (y1_pred + 1e-1), 1)
    y3_calc = torch.sum((y2 > 0) / (y1 + 1e-1), 1)
    loss3 = criterions[0](y3_pred, y3_calc)

    loss = 3*loss1 + 5*loss2 + 2*loss3
    return loss, loss1, loss2, loss3


def _report(metrics_callback, loss_list, loss1, loss2, loss3):
    """Pass the running-mean total loss and this batch's component losses ('%.3f') to the callback."""
    if metrics_callback is not None:
        metrics_callback(
            '%.3f' % np.mean(loss_list),
            '%.3f' % loss1.item(),
            '%.3f' % loss2.item(),
            '%.3f' % loss3.item())


def train(device, model, pbar, optimizer, criterions, metrics_callback=None):
    """Run one training epoch over ``pbar``, an iterable of (X, y1, y2, y3) batches.

    Returns:
        Mean total loss over the epoch.
    """
    model.train()
    loss_list = []
    for X, y1, y2, _ in pbar:
        X, y1, y2 = X.to(device), y1.to(device), y2.to(device)
        loss, loss1, loss2, loss3 = _compute_losses(model, X, y1, y2, criterions)

        optimizer.zero_grad()
        loss.backward()
        # Clip before the update. Clipping after optimizer.step() (as before) had no effect.
        nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
        optimizer.step()
        loss_list.append(loss.item())

        _report(metrics_callback, loss_list, loss1, loss2, loss3)
    return np.mean(loss_list)


def valid(device, model, pbar, criterions, metrics_callback=None):
    """Evaluate the losses over ``pbar`` without gradients; returns the mean total loss."""
    model.eval()
    loss_list = []
    with torch.no_grad():
        for X, y1, y2, _ in pbar:
            X, y1, y2 = X.to(device), y1.to(device), y2.to(device)
            loss, loss1, loss2, loss3 = _compute_losses(model, X, y1, y2, criterions)
            loss_list.append(loss.item())
            _report(metrics_callback, loss_list, loss1, loss2, loss3)
    return np.mean(loss_list)


def inference(device, model, pbar, metrics_callback=None):
    """Report predicted, label-derived and ground-truth counts for the FIRST batch only.

    Counts are rounded sums over frames of periodic / period_length. At most 8 values
    of each are passed to ``metrics_callback(y_pred, y_calc, y_true)``.
    """
    # TODO only one test
    model.eval()
    with torch.no_grad():
        for X, y1, y2, y3_true in pbar:
            X, y1, y2 = X.to(device), y1.to(device), y2.to(device)
            y1_pred, y2_pred = model(X)

            y3_pred = torch.round(torch.sum((y2_pred > 0) / (y1_pred + 1e-1), 1))
            y3_calc = torch.round(torch.sum((y2 > 0) / (y1 + 1e-1), 1))

            if metrics_callback is not None:
                metrics_callback(
                        y3_pred.cpu().numpy().flatten().astype(int).tolist()[:8],
                        y3_calc.cpu().numpy().flatten().astype(int).tolist()[:8],
                        y3_true.numpy().flatten().astype(int).tolist()[:8])

            break


def train_loop(num_epochs, model, ckpt_path,
               train_loader, valid_loader, test_loader,
               optimizer, scheduler, criterions, device):
    """Train for ``num_epochs`` epochs, resuming model weights from ``ckpt_path`` if it exists.

    Each epoch:
      * train and validate;
      * log first-batch count predictions on the test and train loaders;
      * step the scheduler;
      * append the progress-bar summaries to ``{DATASET_PREFIX}/metrics.txt``;
      * overwrite the checkpoint.
    The optimizer state is saved but not restored on resume.
    """
    start_epoch = 0
    fmode = 'w+'

    # load model
    if os.path.exists(ckpt_path):
        # map_location lets a checkpoint saved on GPU be resumed on another device / CPU.
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        fmode = 'a+'

    def current_lr():
        # Public source of the current LR for every scheduler type (`_last_lr` is private).
        return '%.7f' % optimizer.param_groups[0]['lr']

    lr = current_lr()

    # metrics log; `with` closes the file even if an epoch raises
    with open(f'{DATASET_PREFIX}/metrics.txt', fmode) as metrics_writer:
        for epoch in tqdm(range(start_epoch, num_epochs + start_epoch), ascii=True):
            # train
            torch.cuda.empty_cache()
            with tqdm(train_loader, total=len(train_loader), desc='train', ascii=True) as pbar:
                train_loss = train(device, model, pbar, optimizer, criterions,
                                   lambda loss, loss1, loss2, loss3: pbar.set_postfix(
                                       epoch=epoch, lr=lr, loss=loss, loss1=loss1, loss2=loss2, loss3=loss3))
                metrics_writer.write('{}\n'.format(pbar))

            # valid
            torch.cuda.empty_cache()
            with tqdm(valid_loader, desc='valid', ascii=True) as pbar:
                valid_loss = valid(device, model, pbar, criterions,
                                   lambda loss, loss1, loss2, loss3: pbar.set_postfix(
                                       epoch=epoch, lr=lr, loss=loss, loss1=loss1, loss2=loss2, loss3=loss3))
                metrics_writer.write('{}\n'.format(pbar))

            # inference
            torch.cuda.empty_cache()
            with tqdm(test_loader, desc='inference test', ascii=True) as pbar:
                inference(device, model, pbar,
                          lambda y_pred, y_calc, y_true: pbar.set_postfix(
                              y_pred=y_pred, y_calc=y_calc, y_true=y_true))
                metrics_writer.write('{}\n'.format(pbar))
            with tqdm(train_loader, desc='inference train', ascii=True) as pbar:
                inference(device, model, pbar,
                          lambda y_pred, y_calc, y_true: pbar.set_postfix(
                              y_pred=y_pred, y_calc=y_calc, y_true=y_true))
                metrics_writer.write('{}\n'.format(pbar))

            # update learning rate
            if isinstance(scheduler, O.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(valid_loss)
            else:
                scheduler.step()
            lr = current_lr()

            metrics_writer.flush()

            # save model
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'valid_loss': valid_loss,
            }
            torch.save(checkpoint, ckpt_path)
    
# Launch training for 1000 epochs.
# NOTE(review): this checkpoint path differs from `ckpt_path` defined in the setup cell —
# confirm which one is intended.
train_loop(1000, model, '/data/last3_0.0001000.pt',
           train_loader, valid_loader, test_loader,
           optimizer, scheduler, criterions, device)

## References

1. https://arxiv.org/pdf/2006.15418.pdf

2. https://colab.research.google.com/github/google-research/google-research/blob/master/repnet/repnet_colab.ipynb#scrollTo=76L5XFonl_Bw

## Other: Training Log Visualization

In [None]:
import re

# Number patterns for pulling metrics out of the tqdm lines in the metrics log.
# Raw strings: '\d' and '\.' in plain literals are invalid escape sequences
# (DeprecationWarning, SyntaxWarning since Python 3.12).
fm_int = r'[-+]?\d+'
fm_flt = r'[-+]?[0-9]+\.[0-9]+'

train_mdata_list = []
valid_mdata_list = []

def parse_log(mdata_list, phase, log):
    """Parse one metrics-log line for ``phase`` ('train' / 'valid') and append its metrics.

    The line must contain ``'{phase}: '`` followed later by the progress-bar postfix
    ``epoch=.., loss=.., loss1=.., loss2=.., loss3=.., lr=..``. tqdm's set_postfix
    writes keyword arguments in sorted key order. Non-matching lines are ignored.

    Args:
        mdata_list: list that receives one dict (epoch, loss, loss1-3, lr) per match.
        phase: progress-bar description to look for.
        log: one line of the metrics file.
    """
    p_epoch = 'epoch=(?P<epoch>%s)' % fm_int
    p_loss = 'loss=(?P<loss>%s)' % fm_flt
    p_loss1 = 'loss1=(?P<loss1>%s)' % fm_flt
    p_loss2 = 'loss2=(?P<loss2>%s)' % fm_flt
    p_loss3 = 'loss3=(?P<loss3>%s)' % fm_flt
    p_lr = r'lr=(?P<lr>.*[^\]])'
    resdata = re.search(r'%s: .*, %s, %s, %s, %s, %s, %s' % (
        phase, p_epoch, p_loss, p_loss1, p_loss2, p_loss3, p_lr), log)

    if resdata:
        grpdata = resdata.groupdict()
        mdata_list.append({
            'epoch': int(grpdata['epoch']),
            'loss': float(grpdata['loss']),
            'loss1': float(grpdata['loss1']),
            'loss2': float(grpdata['loss2']),
            'loss3': float(grpdata['loss3']),
            'lr': float(grpdata['lr'])
        })

# Metrics log to plot. train_loop writes `{DATASET_PREFIX}/metrics.txt`.
# NOTE(review): this path points elsewhere — confirm which run's log is intended.
METRICS_LOG_PATH = '/data/metrics3.txt'

# Start from empty lists so re-running this cell does not duplicate points.
train_mdata_list.clear()
valid_mdata_list.clear()
with open(METRICS_LOG_PATH) as fr:
    for line in fr.read().split('\n'):
        parse_log(train_mdata_list, 'train', line)
        parse_log(valid_mdata_list, 'valid', line)


def _loss_trace(mdata_list, name):
    """Line trace of total loss vs. epoch; hover text shows each point's learning rate."""
    return go.Scatter(
        x=[x['epoch'] for x in mdata_list],
        y=[x['loss'] for x in mdata_list],
        # Hover text comes from the same list as x/y (the train trace used to read valid_mdata_list).
        text=['lr: %.6f' % x['lr'] for x in mdata_list],
        mode='lines',
        name=name,
    )


fig = go.Figure()
fig.add_trace(_loss_trace(train_mdata_list, 'train loss'))
fig.add_trace(_loss_trace(valid_mdata_list, 'valid loss'))
fig.update_layout(xaxis_title='epoch', yaxis_title='loss')
fig